|
| 1 | +import os |
| 2 | +import re |
| 3 | +import subprocess |
| 4 | +import time |
| 5 | +from datetime import datetime |
| 6 | + |
| 7 | +import requests |
| 8 | +import yaml |
| 9 | + |
# Baseten API key taken from the CI environment; a KeyError here means the
# BASETEN_API_KEY secret is not configured for this workflow run.
API_KEY = os.environ["BASETEN_API_KEY"]
| 11 | + |
| 12 | + |
def get_model_dir():
    """Locate the truss directory of the model changed on this branch.

    Diffs HEAD against origin/main and looks for a changed file whose path
    ends in ``model/model.py``; the truss root is two directory levels up
    from that file.

    Returns:
        The path of the truss root directory containing the changed model.

    Raises:
        Exception: if the git diff fails or no changed model file is found.
    """
    result = subprocess.run(
        ["git", "diff", "--name-only", "origin/main", "HEAD"],
        capture_output=True,
        text=True,
    )
    # Without this check a failed diff (e.g. missing origin/main) would
    # surface as the misleading "No model file found" below.
    if result.returncode != 0:
        raise Exception(f"git diff failed: {result.stderr}")
    changed_files = result.stdout.split("\n")
    print(changed_files)
    for file in changed_files:
        if re.match(r".*/model/model\.py", file):
            # <truss_root>/model/model.py -> <truss_root>
            return os.path.dirname(os.path.dirname(file))
    raise Exception("No model file found")
| 23 | + |
| 24 | + |
def get_example_input(image_str):
    """Read ./config.yaml and return its example model input.

    Validates that the config declares model_metadata.example_model_input,
    defaults model_name to "model" (persisting that default back into
    config.yaml), and substitutes *image_str* into any image-like key of
    the example input.

    Args:
        image_str: Image payload (presumably base64 — confirm against
            bin/image.txt) substituted into the example input.

    Returns:
        The example input ready to POST to the deployed model.

    Raises:
        Exception: if config.yaml is missing, is invalid YAML, or lacks
            the required metadata keys.
    """
    if not os.path.exists("config.yaml"):
        raise Exception("No config.yaml found. You must implement a config.yaml file.")
    with open("config.yaml", "r") as f:
        try:
            # `or {}`: an empty-but-valid YAML file loads as None, which
            # would make the membership checks below raise TypeError.
            loaded_config = yaml.safe_load(f) or {}
        except yaml.YAMLError as e:
            raise Exception(f"Invalid config.yaml: {e}")

    if "model_metadata" not in loaded_config:
        raise Exception(
            "No model_metadata found in config.yaml. Config must include model_metadata with an example_model_input value."
        )

    if "example_model_input" not in loaded_config["model_metadata"]:
        raise Exception("No example_model_input found in model_metadata")

    if "model_name" not in loaded_config:
        # Persist a default name so the subsequent `truss push` has one.
        loaded_config["model_name"] = "model"
        with open("config.yaml", "w") as f:
            f.write(yaml.safe_dump(loaded_config))

    example_input = loaded_config["model_metadata"]["example_model_input"]
    # Swap any image-bearing field for the provided test image.
    for key in ["image", "b64_image", "base64_image", "base_image", "input_image"]:
        if key in example_input:
            example_input[key] = image_str
    return example_input
| 53 | + |
| 54 | + |
def truss_push():
    """Push the current truss to Baseten and return its identifiers.

    Writes a .trussrc pointing at app.baseten.co with the CI API key,
    invokes `truss push --trusted`, and parses the model id and deployment
    id out of the "View logs" URL in the CLI output.

    Returns:
        A (model_id, deployment_id) tuple.

    Raises:
        Exception: if the push output does not contain the logs URL.
    """
    print("Pushing model...")
    trussrc = f"""[baseten]
remote_provider = baseten
api_key = {API_KEY}
remote_url = https://app.baseten.co"""
    with open("/home/runner/.trussrc", "w") as config_file:
        config_file.write(trussrc)

    proc = subprocess.run(["truss", "push", "--trusted"], capture_output=True)
    stdout = proc.stdout.decode()
    match = re.search(
        r"View logs for your deployment at \n?https://app\.baseten\.co/models/(\w+)/logs/(\w+)",
        stdout,
    )
    if match is None:
        raise Exception(
            f"Failed to push model:\n\nSTDOUT: {stdout}\nSTDERR: {proc.stderr.decode()}"
        )
    model_id, deployment_id = match.groups()
    print(
        f"Model pushed successfully. model-id: {model_id}. deployment-id: {deployment_id}"
    )
    return model_id, deployment_id
| 80 | + |
| 81 | + |
def truss_predict(model_id, input):
    """Poll the model's development deployment until it serves a prediction.

    Posts *input* to the development predict endpoint every 30 seconds, for
    up to 30 minutes, while the model keeps reporting that it is still
    building or deploying.

    Args:
        model_id: Baseten model id (as returned by truss_push).
        input: JSON-serializable payload sent as the request body.

    Returns:
        The decoded JSON response (a dict, or any other JSON value the
        model returns), or None when a response body cannot be decoded as
        JSON. On timeout, the last "not ready" error dict is returned.
    """
    # Seed with the exact sentinel error so the loop body runs at least once;
    # the loop condition matches this string verbatim.
    result = {"error": "Model is not ready, it is still building or deploying"}
    seconds_remaining = 60 * 30  # Wait for 30 minutes
    while (
        "error" in result
        and result["error"] == "Model is not ready, it is still building or deploying"
        and seconds_remaining > 0
    ):
        print(f"{round(seconds_remaining / 60, 2)} minutes remaining")
        result = requests.post(
            f"https://model-{model_id}.api.baseten.co/development/predict",
            headers={"Authorization": f"Api-Key {API_KEY}"},
            json=input,
        )

        try:
            # Replace the Response object with its decoded JSON body.
            result = result.json()
        except requests.exceptions.JSONDecodeError as e:
            print(f"Failed to decode JSON: {e}")
            print(result.text)
            # NOTE(review): implicit None return — the __main__ caller does
            # `"error" in result`, which raises TypeError on None. Confirm
            # whether callers are expected to handle None.
            return

        # Non-dict JSON (e.g. a list) cannot carry the "not ready" error,
        # so it is a final answer.
        if not isinstance(result, dict):
            return result

        seconds_remaining -= 30
        print("Waiting for model to be ready...")
        time.sleep(30)

    return result
| 112 | + |
| 113 | + |
def get_truss_logs(deployment_id, start_time):
    """Fetch logs for a deployment from the Baseten logs endpoint.

    Args:
        deployment_id: Deployment whose logs to fetch.
        start_time: Window start, epoch milliseconds.

    Returns:
        The decoded JSON response. On failure the endpoint answers
        {'success': False, 'message': 'Failed to load logs'}.
    """
    payload = {
        "type": "MODEL",
        "start": start_time,
        "end": get_time_in_ms(),
        "levels": [],
        "regex": "",
        "limit": 500,
        "entity_id": deployment_id,
        "direction": "backward",
    }
    response = requests.post(
        "https://app.baseten.co/logs",
        headers={"Authorization": f"Api-Key {API_KEY}"},
        json=payload,
    )
    return response.json()
| 131 | + |
| 132 | + |
def deactivate_truss(model_id):
    """Deactivate the model's production deployment via the Baseten API."""
    print("Deactivating model...")
    url = f"https://api.baseten.co/v1/models/{model_id}/deployments/production/deactivate"
    result = requests.post(url, headers={"Authorization": f"Api-Key {API_KEY}"})
    # NOTE(review): success is reported unconditionally — the response is
    # printed but its status is never checked.
    print("Model deactivated successfully")
    print(result)
| 141 | + |
| 142 | + |
| 143 | +def print_formatted_logs(logs): |
| 144 | + for log in reversed(logs): |
| 145 | + ts = datetime.fromtimestamp(int(log["ts"]) // 1_000_000_000).ctime() |
| 146 | + print(f"{ts} - {log['level']} - {log['msg']}") |
| 147 | + |
| 148 | + |
def get_time_in_ms():
    """Return the current Unix time in whole milliseconds."""
    nanos = time.time_ns()
    return nanos // 1_000_000
| 151 | + |
| 152 | + |
if __name__ == "__main__":
    # CI entry point: find the changed model, push it, smoke-test a
    # prediction, surface logs on failure, then deactivate the deployment.
    model_dir = get_model_dir()

    # Test image payload; read before chdir since the path is repo-relative.
    with open("bin/image.txt", "r") as f:
        image_str = f.read()

    os.chdir(model_dir)
    example_input = get_example_input(image_str)
    model_id, deployment_id = truss_push()
    start_time = get_time_in_ms()
    result = truss_predict(model_id, example_input)
    print(f"Model prediction result: {result}")
    # truss_predict returns None when the response body was not JSON; only
    # guard dict results for "error" so list/string predictions count as
    # success instead of raising TypeError on the membership test.
    if result is None or (isinstance(result, dict) and "error" in result):
        logs = get_truss_logs(deployment_id, start_time)
        if logs["success"]:
            print_formatted_logs(logs["logs"])
        else:
            print("Failed to load logs")
        error = result["error"] if isinstance(result, dict) else result
        print(f"Failed to make prediction. Received error from model {error}")
        # Exit nonzero on every failure path — previously a failed
        # prediction with unloadable logs exited 0 and CI passed.
        exit(1)

    deactivate_truss(model_id)
0 commit comments