Skip to content

Commit e772189

Browse files
authored
Merge pull request #54 from PriorLabs/track-model-fit-mode
Track `fit_mode` and general model init params
2 parents 3a11bde + 423059d commit e772189

File tree

4 files changed

+49
-2
lines changed

4 files changed

+49
-2
lines changed

src/tabpfn_common_utils/telemetry/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
set_extension,
88
get_current_extension,
99
set_model_config,
10+
set_init_params,
11+
get_init_params,
1012
)
1113

1214
# Public exports
@@ -21,4 +23,6 @@
2123
"set_extension",
2224
"get_current_extension",
2325
"set_model_config",
26+
"set_init_params",
27+
"get_init_params",
2428
]

src/tabpfn_common_utils/telemetry/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
set_extension,
1717
get_current_extension,
1818
set_model_config,
19+
set_init_params,
20+
get_init_params,
1921
)
2022

2123
# Public exports
@@ -32,4 +34,6 @@
3234
"set_extension",
3335
"get_current_extension",
3436
"set_model_config",
37+
"set_init_params",
38+
"get_init_params",
3539
]

src/tabpfn_common_utils/telemetry/core/decorators.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dataclasses import dataclass
1414
from pathlib import Path
1515
from functools import wraps
16-
from typing import Any, Callable, Literal, Optional, Tuple, Union
16+
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
1717

1818
from .events import FitEvent, PredictEvent
1919
from .service import capture_event
@@ -85,6 +85,38 @@ def get_model_config() -> Optional[Tuple[str, str]]:
8585
return None
8686

8787

88+
def set_init_params(
89+
params: Dict[str, Any],
90+
) -> Optional[contextvars.Token[Optional[str]]]:
91+
"""Set the initial parameters of the model.
92+
93+
Args:
94+
params: The initial parameters of the model.
95+
"""
96+
try:
97+
token = json.dumps(params)
98+
tok = _get_context_var("tabpfn_model_init_params").set(token)
99+
return tok
100+
except Exception:
101+
return None
102+
103+
104+
def get_init_params() -> Optional[Dict[str, Any]]:
105+
"""Get the initial parameters of the model.
106+
107+
Returns:
108+
The initial parameters of the model.
109+
"""
110+
token = _get_context_var("tabpfn_model_init_params").get()
111+
if token is None:
112+
return None
113+
114+
try:
115+
return json.loads(token)
116+
except Exception:
117+
return None
118+
119+
88120
def get_current_extension() -> Optional[str]:
89121
"""Get the current extension.
90122
@@ -383,6 +415,10 @@ def _send_model_called_event(call_info: _ModelCallInfo, duration_ms: int) -> Non
383415
event.model_path = model_path
384416
event.model_version = model_version
385417

418+
# Set the model init params for fit
419+
if isinstance(event, FitEvent):
420+
event.init_params = get_init_params()
421+
386422
except TypeError as e:
387423
logger.debug(f"Event creation failed: {e}")
388424
return

src/tabpfn_common_utils/telemetry/core/events.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, asdict, field
66
from datetime import datetime, timezone
77
from functools import lru_cache
8-
from typing import Any, Literal, Optional
8+
from typing import Any, Dict, Literal, Optional
99
from .runtime import get_execution_context
1010
from .state import get_property, set_property
1111

@@ -373,6 +373,9 @@ class FitEvent(ModelCallEvent):
373373
Event emitted when a model is fit.
374374
"""
375375

376+
# Initial parameters of the model
377+
init_params: Optional[Dict[str, Any]] = field(default=None, init=False)
378+
376379
@property
377380
def name(self) -> str:
378381
return "fit_called"

0 commit comments

Comments
 (0)