Skip to content

Commit 234dd8f

Browse files
Fix stack traces on exit during script (#176)
1 parent 835c874 commit 234dd8f

File tree

5 files changed

+28
-15
lines changed

5 files changed

+28
-15
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "tabpfn-client"
7-
version = "0.2.7"
7+
version = "0.2.8"
88
requires-python = ">=3.9"
99
dynamic = ["dependencies", "optional-dependencies"]
1010

tabpfn_client/prompt_agent.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Prior Labs GmbH 2025.
22
# Licensed under the Apache License, Version 2.0
33
import getpass
4+
import sys
45
import textwrap
56
from rich.table import Table
67

@@ -17,6 +18,17 @@
1718
)
1819

1920

21+
def maybe_graceful_exit() -> None:
22+
try:
23+
from IPython import get_ipython
24+
25+
if get_ipython() is not None:
26+
return
27+
except ImportError:
28+
# We're in a script, just exit
29+
sys.exit(1)
30+
31+
2032
class PromptAgent:
2133
def __new__(cls):
2234
raise RuntimeError(
@@ -117,6 +129,7 @@ def prompt_and_set_token(cls) -> bool:
117129
return True
118130
except KeyboardInterrupt:
119131
console.print("\n\n[yellow]Interrupted. Goodbye![/yellow]")
132+
maybe_graceful_exit()
120133
return False
121134

122135
@classmethod
@@ -145,6 +158,7 @@ def _prompt_and_set_token_impl(cls) -> bool:
145158

146159
if choice == "q":
147160
console.print("Goodbye!")
161+
maybe_graceful_exit()
148162
return False
149163

150164
# Registration
@@ -347,6 +361,7 @@ def _prompt_and_set_token_impl(cls) -> bool:
347361
continue
348362
elif retry_choice == "q":
349363
console.print("Goodbye!")
364+
maybe_graceful_exit()
350365
return False
351366
else:
352367
# Invalid choice, use default (retry)
@@ -551,7 +566,8 @@ def reverify_email(cls, access_token):
551566
return "restart" # Signal to show main menu
552567
elif choice in ["q", "quit"]:
553568
console.print("Goodbye!")
554-
return False # Signal to quit without showing menu
569+
maybe_graceful_exit()
570+
return False
555571
else:
556572
warn("Please enter 1, 2, or q.")
557573

tabpfn_client/tests/quick_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from sklearn.model_selection import train_test_split
1414

1515
# from tabpfn_client import UserDataClient
16-
from tabpfn_client.constants import ModelVersion
1716
from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor
17+
from tabpfn_client.constants import ModelVersion
1818

1919
logging.basicConfig(level=logging.INFO)
2020

tabpfn_client/tests/quick_test_v1.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import json
3-
import pandas as pd
43
import requests
54
from typing import Optional
65

@@ -9,20 +8,22 @@
98
TARGET_NAME = "target"
109

1110

12-
def call_fit(train_path: str, target_name: str = TARGET_NAME, api_key: Optional[str] = None) -> str:
11+
def call_fit(
12+
train_path: str, target_name: str = TARGET_NAME, api_key: Optional[str] = None
13+
) -> str:
1314
"""
1415
Call the /v1/fit endpoint to train a model.
15-
16+
1617
Args:
1718
train_path: Path to the training CSV file
1819
target_name: Name of the target column (default: "target")
1920
api_key: API key for authentication (if None, reads from PRIORLABS_API_KEY env var)
20-
21+
2122
Returns:
2223
model_id: The model ID returned from the fit endpoint
2324
"""
2425
headers = {"Authorization": f"Bearer {api_key}"}
25-
26+
2627
payload = {
2728
"task": "classification",
2829
"schema": {
@@ -52,14 +53,14 @@ def call_fit(train_path: str, target_name: str = TARGET_NAME, api_key: Optional[
5253
def call_predict(test_path: str, model_id: str, api_key: Optional[str] = None) -> None:
5354
"""
5455
Call the /v1/predict endpoint to get predictions.
55-
56+
5657
Args:
5758
test_path: Path to the test CSV file
5859
model_id: The model ID from the fit call
5960
api_key: API key for authentication (if None, reads from PRIORLABS_API_KEY env var)
6061
"""
6162
headers = {"Authorization": f"Bearer {api_key}"}
62-
63+
6364
payload = {
6465
"task": "classification",
6566
"model_id": model_id,
@@ -99,7 +100,7 @@ def main() -> None:
99100
# Get the API key
100101
api_key = os.getenv("PRIORLABS_API_KEY")
101102

102-
#Test /v1/fit
103+
# Test /v1/fit
103104
print("--- Testing /v1/fit ---")
104105
model_id = call_fit(train_path, api_key=api_key)
105106

tabpfn_client/ui.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from contextlib import contextmanager
99
from typing import Iterator, List
1010

11-
from rich import traceback
1211
from rich.console import Console
1312
from rich.logging import RichHandler
1413
from rich.panel import Panel
@@ -33,9 +32,6 @@ def _collect_suppressed_modules() -> List[object]:
3332
return suppressed
3433

3534

36-
traceback.install(show_locals=False, suppress=_collect_suppressed_modules())
37-
38-
3935
def _should_use_color() -> bool:
4036
"""Determine whether color output should be used."""
4137

0 commit comments

Comments
 (0)