Skip to content

Commit 835c874

Browse files
authored
Push changes ahead of the release (#175)
* Push changes ahead of the release * Fix unit test
1 parent be2c2c0 commit 835c874

File tree

4 files changed

+115
-3
lines changed

4 files changed

+115
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "tabpfn-client"
7-
version = "0.2.5"
7+
version = "0.2.7"
88
requires-python = ">=3.9"
99
dynamic = ["dependencies", "optional-dependencies"]
1010

tabpfn_client/server_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# production
77
protocol: "https"
8-
host: "tabpfn-server-wjedmz7r5a-ez.a.run.app"
8+
host: "api.priorlabs.ai"
99
port: "443"
1010
gui_url: "https://ux.priorlabs.ai"
1111
endpoints:

tabpfn_client/tests/integration/test_server_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def setUp(self):
1313
self.config = yaml.safe_load(f)
1414

1515
def test_host_configuration(self):
16-
expected_host = "tabpfn-server-wjedmz7r5a-ez.a.run.app"
16+
expected_host = "api.priorlabs.ai"
1717
self.assertEqual(
1818
self.config["host"], expected_host, f"Host should be {expected_host}"
1919
)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import os
2+
import json
3+
import pandas as pd
4+
import requests
5+
from typing import Optional
6+
7+
from tabpfn_client.tabpfn_common_utils.utils import get_example_dataset
8+
9+
TARGET_NAME = "target"
10+
11+
12+
def call_fit(train_path: str, target_name: str = TARGET_NAME, api_key: Optional[str] = None) -> str:
13+
"""
14+
Call the /v1/fit endpoint to train a model.
15+
16+
Args:
17+
train_path: Path to the training CSV file
18+
target_name: Name of the target column (default: "target")
19+
api_key: API key for authentication (if None, reads from PRIORLABS_API_KEY env var)
20+
21+
Returns:
22+
model_id: The model ID returned from the fit endpoint
23+
"""
24+
headers = {"Authorization": f"Bearer {api_key}"}
25+
26+
payload = {
27+
"task": "classification",
28+
"schema": {
29+
"target": target_name,
30+
"description": "Iris dataset (quick test)",
31+
},
32+
}
33+
files = {
34+
"data": (None, json.dumps(payload), "application/json"),
35+
"dataset_file": (train_path, open(train_path, "rb")),
36+
}
37+
response = requests.post(
38+
"http://localhost:8000/v1/fit",
39+
headers=headers,
40+
files=files,
41+
)
42+
if response.status_code != 200:
43+
raise RuntimeError(f"[FIT] HTTP {response.status_code}: {response.text}")
44+
res_j = response.json()
45+
model_id = res_j.get("model_id")
46+
if not model_id:
47+
raise RuntimeError(f"[FIT] No model_id in response: {res_j}")
48+
print(f"✅ Model trained: {model_id}")
49+
return model_id
50+
51+
52+
def call_predict(test_path: str, model_id: str, api_key: Optional[str] = None) -> None:
53+
"""
54+
Call the /v1/predict endpoint to get predictions.
55+
56+
Args:
57+
test_path: Path to the test CSV file
58+
model_id: The model ID from the fit call
59+
api_key: API key for authentication (if None, reads from PRIORLABS_API_KEY env var)
60+
"""
61+
headers = {"Authorization": f"Bearer {api_key}"}
62+
63+
payload = {
64+
"task": "classification",
65+
"model_id": model_id,
66+
}
67+
files = {
68+
"data": (None, json.dumps(payload), "application/json"),
69+
"file": (test_path, open(test_path, "rb")),
70+
}
71+
response = requests.post(
72+
"http://localhost:8000/v1/predict",
73+
headers=headers,
74+
files=files,
75+
)
76+
if response.status_code != 200:
77+
raise RuntimeError(f"[PREDICT] HTTP {response.status_code}: {response.text}")
78+
print("✅ Predictions:")
79+
print(json.dumps(response.json(), indent=2))
80+
81+
82+
def main() -> None:
83+
"""Main function to generate dataset and test both endpoints."""
84+
# === Generate/train/test data as in quick_test.py ===
85+
x_train, x_test, y_train, y_test = get_example_dataset("iris")
86+
87+
train_df = x_train.copy()
88+
train_df[TARGET_NAME] = y_train.values
89+
90+
test_df = x_test.copy()
91+
test_df[TARGET_NAME] = y_test.values
92+
93+
train_path = "train.csv"
94+
test_path = "test.csv"
95+
96+
train_df.to_csv(train_path, index=False)
97+
test_df.to_csv(test_path, index=False)
98+
99+
# Get the API key
100+
api_key = os.getenv("PRIORLABS_API_KEY")
101+
102+
#Test /v1/fit
103+
print("--- Testing /v1/fit ---")
104+
model_id = call_fit(train_path, api_key=api_key)
105+
106+
# Test /v1/predict
107+
print("\n--- Testing /v1/predict ---")
108+
call_predict(test_path, model_id, api_key=api_key)
109+
110+
111+
if __name__ == "__main__":
112+
main()

0 commit comments

Comments
 (0)