Skip to content

Commit 43da6dc

Browse files
Droid github action (#245)
Add a github action that gets triggered on droid PRs. This will push trusses and run predictions to see if the truss works correctly. If the test fails a comment gets left on the PR that droid then responds to
1 parent ed12b4a commit 43da6dc

File tree

3 files changed

+210
-0
lines changed

3 files changed

+210
-0
lines changed

.github/workflows/truss_deploy.yml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
name: Droid deploy truss
2+
3+
on:
4+
pull_request:
5+
types: [opened, synchronize, reopened, ready_for_review]
6+
branches:
7+
- main
8+
paths:
9+
- '**/model.py'
10+
11+
jobs:
12+
truss_test:
13+
if: ${{ github.triggering_actor == 'factory-droid-dev' || github.triggering_actor == 'factory-droid-dev[bot]'}}
14+
runs-on: ubuntu-latest
15+
steps:
16+
- name: Checkout code
17+
uses: actions/checkout@v4
18+
with:
19+
fetch-depth: 0
20+
21+
- name: Set up Python
22+
uses: actions/setup-python@v5
23+
with:
24+
python-version: 3.11
25+
26+
- name: Install dependencies (if any)
27+
run: |
28+
python -m pip install --upgrade pip
29+
pip install git+https://github.com/basetenlabs/truss.git requests tenacity --upgrade
30+
31+
- name: Run tests
32+
env:
33+
BASETEN_API_KEY: ${{ secrets.BASETEN_API_KEY }}
34+
run: python bin/test_truss_deploy.py

bin/image.txt

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

bin/test_truss_deploy.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import os
2+
import re
3+
import subprocess
4+
import time
5+
from datetime import datetime
6+
7+
import requests
8+
import yaml
9+
10+
API_KEY = os.environ["BASETEN_API_KEY"]
11+
12+
13+
def get_model_dir():
14+
result = subprocess.run(
15+
["git", "diff", "--name-only", "origin/main", "HEAD"], capture_output=True
16+
)
17+
changed_files = result.stdout.decode().split("\n")
18+
print(changed_files)
19+
for file in changed_files:
20+
if re.match(r".*/model/model\.py", file):
21+
return os.path.dirname(os.path.dirname(file))
22+
raise Exception("No model file found")
23+
24+
25+
def get_example_input(image_str):
26+
ls = os.listdir(".")
27+
if "config.yaml" not in ls:
28+
raise Exception("No config.yaml found. You must implement a config.yaml file.")
29+
with open("config.yaml", "r") as f:
30+
try:
31+
loaded_config = yaml.safe_load(f.read())
32+
except yaml.YAMLError as e:
33+
raise Exception(f"Invalid config.yaml: {e}")
34+
35+
if "model_metadata" not in loaded_config:
36+
raise Exception(
37+
"No model_metadata found in config.yaml. Config must include model_metadata with an example_model_input value."
38+
)
39+
40+
if "example_model_input" not in loaded_config["model_metadata"]:
41+
raise Exception("No example_model_input found in model_metadata")
42+
43+
if "model_name" not in loaded_config:
44+
loaded_config["model_name"] = "model"
45+
with open("config.yaml", "w") as f:
46+
f.write(yaml.safe_dump(loaded_config))
47+
48+
example_input = loaded_config["model_metadata"]["example_model_input"]
49+
for key in ["image", "b64_image", "base64_image", "base_image", "input_image"]:
50+
if key in example_input:
51+
example_input[key] = image_str
52+
return example_input
53+
54+
55+
def truss_push():
56+
print("Pushing model...")
57+
with open("/home/runner/.trussrc", "w") as config_file:
58+
config_file.write(
59+
f"""[baseten]
60+
remote_provider = baseten
61+
api_key = {API_KEY}
62+
remote_url = https://app.baseten.co"""
63+
)
64+
65+
result = subprocess.run(["truss", "push", "--trusted"], capture_output=True)
66+
match = re.search(
67+
r"View logs for your deployment at \n?https://app\.baseten\.co/models/(\w+)/logs/(\w+)",
68+
result.stdout.decode(),
69+
)
70+
if not match:
71+
raise Exception(
72+
f"Failed to push model:\n\nSTDOUT: {result.stdout.decode()}\nSTDERR: {result.stderr.decode()}"
73+
)
74+
model_id = match.group(1)
75+
deployment_id = match.group(2)
76+
print(
77+
f"Model pushed successfully. model-id: {model_id}. deployment-id: {deployment_id}"
78+
)
79+
return model_id, deployment_id
80+
81+
82+
def truss_predict(model_id, input):
83+
result = {"error": "Model is not ready, it is still building or deploying"}
84+
seconds_remaining = 60 * 30 # Wait for 30 minutes
85+
while (
86+
"error" in result
87+
and result["error"] == "Model is not ready, it is still building or deploying"
88+
and seconds_remaining > 0
89+
):
90+
print(f"{round(seconds_remaining / 60, 2)} minutes remaining")
91+
result = requests.post(
92+
f"https://model-{model_id}.api.baseten.co/development/predict",
93+
headers={"Authorization": f"Api-Key {API_KEY}"},
94+
json=input,
95+
)
96+
97+
try:
98+
result = result.json()
99+
except requests.exceptions.JSONDecodeError as e:
100+
print(f"Failed to decode JSON: {e}")
101+
print(result.text)
102+
return
103+
104+
if not isinstance(result, dict):
105+
return result
106+
107+
seconds_remaining -= 30
108+
print("Waiting for model to be ready...")
109+
time.sleep(30)
110+
111+
return result
112+
113+
114+
def get_truss_logs(deployment_id, start_time):
115+
result = requests.post(
116+
"https://app.baseten.co/logs",
117+
headers={"Authorization": f"Api-Key {API_KEY}"},
118+
json={
119+
"type": "MODEL",
120+
"start": start_time,
121+
"end": get_time_in_ms(),
122+
"levels": [],
123+
"regex": "",
124+
"limit": 500,
125+
"entity_id": deployment_id,
126+
"direction": "backward",
127+
},
128+
)
129+
# {'success': False, 'message': 'Failed to load logs'}
130+
return result.json()
131+
132+
133+
def deactivate_truss(model_id):
134+
print("Deactivating model...")
135+
result = requests.post(
136+
f"https://api.baseten.co/v1/models/{model_id}/deployments/production/deactivate",
137+
headers={"Authorization": f"Api-Key {API_KEY}"},
138+
)
139+
print("Model deactivated successfully")
140+
print(result)
141+
142+
143+
def print_formatted_logs(logs):
144+
for log in reversed(logs):
145+
ts = datetime.fromtimestamp(int(log["ts"]) // 1_000_000_000).ctime()
146+
print(f"{ts} - {log['level']} - {log['msg']}")
147+
148+
149+
def get_time_in_ms():
150+
return time.time_ns() // 1_000_000
151+
152+
153+
if __name__ == "__main__":
154+
model_dir = get_model_dir()
155+
156+
image_str = open("bin/image.txt", "r").read()
157+
158+
os.chdir(model_dir)
159+
example_input = get_example_input(image_str)
160+
model_id, deployment_id = truss_push()
161+
start_time = get_time_in_ms()
162+
result = truss_predict(model_id, example_input)
163+
print(f"Model prediction result: {result}")
164+
if "error" in result:
165+
logs = get_truss_logs(deployment_id, start_time)
166+
if logs["success"]:
167+
print_formatted_logs(logs["logs"])
168+
print(
169+
f"Failed to make prediction. Received error from model {result['error']}"
170+
)
171+
exit(1)
172+
else:
173+
print("Failed to load logs")
174+
175+
deactivate_truss(model_id)

0 commit comments

Comments
 (0)