# Copyright (c) 2024 Oracle and/or its affiliates.
#
# The Universal Permissive License (UPL), Version 1.0
#
# Subject to the condition set forth below, permission is hereby granted to any
# person obtaining a copy of this software, associated documentation and/or data
# (collectively the "Software"), free of charge and under any and all copyright
# rights in the Software, and any and all patent rights owned or freely
# licensable by each licensor hereunder covering either (i) the unmodified
# Software as contributed to or provided by such licensor, or (ii) the Larger
# Works (as defined below), to deal in both
#
# (a) the Software, and
# (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
# one is included with the Software (each a "Larger Work" to which the Software
# is contributed by such licensors),
#
# without restriction, including without limitation the rights to copy, create
# derivative works of, display, perform, and distribute the Software and make,
# use, sell, offer for sale, import, export, have made, and have sold the
# Software and the Larger Work(s), and to sublicense the foregoing rights on
# either these or other terms.
#
# This license is subject to the following condition:
# The above copyright notice and either this complete permission notice or at
# a minimum a reference to the UPL must be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import time

import pynvml
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

def print_vram_usage():
    """Print the total, free, and used VRAM of GPU 0."""
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    print(f"Total VRAM: {info.total / 1024**2:.2f} MB")
    print(f"Free VRAM: {info.free / 1024**2:.2f} MB")
    print(f"Used VRAM: {info.used / 1024**2:.2f} MB")
    # Pair every nvmlInit with a shutdown to release the NVML handle
    pynvml.nvmlShutdown()

# Record the start time of the overall initialization
start_time = time.time()

print("Loading tokenizer and model...")

# Path to the locally stored model weights
model_name = "/share/app/llama3-8b-instruct/"

# Load in torch.bfloat16 to roughly halve memory usage versus float32
model_loading_start_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
print(f"Model loaded in {time.time() - model_loading_start_time:.2f} seconds.")
print_vram_usage()

tokenizer_loading_start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Tokenizer loaded in {time.time() - tokenizer_loading_start_time:.2f} seconds.")
print_vram_usage()

# DataParallel replicates the model across all visible GPUs so that
# batched forward passes are split between them
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = torch.nn.DataParallel(model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Time the device transfer separately from model loading
device_move_start_time = time.time()
model.to(device)

print(f"Model moved to {device} in {time.time() - device_move_start_time:.2f} seconds.")
print_vram_usage()

class InferenceRequest(BaseModel):
    input_text: str

@app.post("/infer")
async def infer(request: InferenceRequest):
    try:
        inference_start_time = time.time()

        # Tokenize the input and move the tensors to the model's device
        inputs = tokenizer(request.input_text, return_tensors="pt")
        inputs = {key: value.to(device) for key, value in inputs.items()}

        # Unwrap DataParallel (if used) to reach generate()
        gen_model = model.module if isinstance(model, torch.nn.DataParallel) else model

        # Inference
        with torch.no_grad():
            outputs = gen_model.generate(
                **inputs,
                max_new_tokens=256,  # cap the response length; tune as needed
                pad_token_id=tokenizer.eos_token_id,  # Llama 3 has no dedicated pad token
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        print(f"Generated text: {generated_text}")
        print(f"Inference completed in {time.time() - inference_start_time:.2f} seconds.")

        return {"generated_text": generated_text}
    except Exception as e:
        print(f"Error during inference: {e}")
        raise HTTPException(status_code=500, detail=str(e))
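
# Example request (a minimal sketch; assumes the server is reachable at
# localhost:8000; adjust host and port to your deployment):
#
#   curl -X POST http://localhost:8000/infer \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Explain bfloat16 in one sentence."}'
#
# The response body has the shape {"generated_text": "..."}.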

if __name__ == "__main__":
    import uvicorn
    print(f"Total initialization time: {time.time() - start_time:.2f} seconds.")
    uvicorn.run(app, host="0.0.0.0", port=8000)
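
# To launch (a sketch; assumes this module is saved as app.py):
#   python app.py
# or via the uvicorn CLI, which skips the __main__ block above:
#   uvicorn app:app --host 0.0.0.0 --port 8000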