Skip to content

Commit 3a11bde

Browse files
authored
Merge pull request #56 from PriorLabs/ENG-377
Keep track of execution environment
2 parents a927b09 + 0b3d906 commit 3a11bde

File tree

5 files changed

+145
-51
lines changed

5 files changed

+145
-51
lines changed

src/tabpfn_common_utils/telemetry/core/events.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import datetime, timezone
77
from functools import lru_cache
88
from typing import Any, Literal, Optional
9-
from .runtime import get_runtime
9+
from .runtime import get_execution_context
1010
from .state import get_property, set_property
1111

1212

@@ -179,8 +179,8 @@ def _get_runtime_kernel() -> Optional[str]:
179179
Returns:
180180
str: Runtime environment of the platform.
181181
"""
182-
runtime = get_runtime()
183-
return runtime.kernel
182+
exec_context = get_execution_context()
183+
return exec_context.kernel
184184

185185

186186
@lru_cache(maxsize=1)

src/tabpfn_common_utils/telemetry/core/runtime.py

Lines changed: 131 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,65 +5,159 @@
55
import os
66
import sys
77
from dataclasses import dataclass
8-
from typing import Literal
8+
from typing import Callable, Dict, Literal, Mapping, Optional, Sequence, Tuple
9+
10+
11+
# The type of kernel the code is running in
12+
KernelType = Literal["ipython", "jupyter", "tty"]
13+
14+
# The type of environment the code is running in
15+
EnvironmentType = Literal[
16+
"kaggle",
17+
"colab",
18+
"gcp",
19+
"aws",
20+
"azure",
21+
"databricks",
22+
]
23+
24+
# Static list of environment hints, purely heuristic based on env variables.
25+
# This information is used purely for detecting purposes, and the values are
26+
# not propagated or stored in the telemetry state.
27+
ENV_TYPE_HINTS: Mapping[EnvironmentType, Sequence[str]] = {
28+
# Notebook providers
29+
"kaggle": [
30+
# Kaggle kernels
31+
"KAGGLE_KERNEL_RUN_TYPE",
32+
"KAGGLE_URL_BASE",
33+
"KAGGLE_KERNEL_INTEGRATIONS",
34+
"KAGGLE_USER_SECRETS_TOKEN",
35+
"KAGGLE_GCP_PROJECT",
36+
"KAGGLE_GCP_ZONE",
37+
],
38+
"colab": [
39+
# Google Colab
40+
"COLAB_GPU",
41+
"COLAB_TPU_ADDR",
42+
"COLAB_JUPYTER_TRANSPORT",
43+
"COLAB_BACKEND_VERSION",
44+
],
45+
"databricks": [
46+
# Databricks clusters and runtime
47+
"DATABRICKS_RUNTIME_VERSION",
48+
"DATABRICKS_CLUSTER_ID",
49+
"DATABRICKS_HOST",
50+
"DATABRICKS_WORKSPACE_URL",
51+
"DB_IS_DRIVER",
52+
],
53+
# Cloud providers
54+
"aws": [
55+
# Generic AWS environment hints
56+
"AWS_EXECUTION_ENV",
57+
"AWS_REGION",
58+
"AWS_DEFAULT_REGION",
59+
# SageMaker
60+
"SM_MODEL_DIR",
61+
"SM_NUM_GPUS",
62+
"SM_HOSTS",
63+
"SM_CURRENT_HOST",
64+
"TRAINING_JOB_NAME",
65+
],
66+
"gcp": [
67+
# Project hints
68+
"GOOGLE_CLOUD_PROJECT",
69+
"GCP_PROJECT",
70+
"GCLOUD_PROJECT",
71+
"CLOUDSDK_CORE_PROJECT",
72+
# Cloud Run and Cloud Functions
73+
"K_SERVICE",
74+
"K_REVISION",
75+
"K_CONFIGURATION",
76+
"CLOUD_RUN_JOB",
77+
# Vertex AI
78+
"AIP_MODEL_DIR",
79+
"AIP_DATA_FORMAT",
80+
"AIP_TRAINING_DATA_URI",
81+
"CLOUD_ML_JOB_ID",
82+
# Cloud Shell
83+
"CLOUD_SHELL",
84+
],
85+
"azure": [
86+
# Azure ML
87+
"AZUREML_RUN_ID",
88+
"AZUREML_ARM_SUBSCRIPTION",
89+
"AZUREML_ARM_RESOURCEGROUP",
90+
"AZUREML_ARM_WORKSPACE_NAME",
91+
],
92+
}
993

1094

1195
@dataclass
12-
class Runtime:
13-
"""Runtime environment."""
96+
class ExecutionContext:
97+
"""The execution context of the current environment."""
1498

1599
interactive: bool
16-
kernel: Literal["ipython", "jupyter", "tty", "kaggle"] | None = None
100+
"""Whether the code is running in an interactive environment."""
101+
102+
kernel: Optional[KernelType] = None
103+
"""Low-level Python frontend or shell (e.g. IPython, Jupyter, TTY)"""
104+
105+
environment: Optional[EnvironmentType] = None
106+
"""Higher-level hosted environment or notebook platform."""
107+
17108
ci: bool = False
109+
"""Whether the code is running in a CI environment."""
18110

19111

20-
def get_runtime() -> Runtime:
21-
"""Get the runtime environment.
112+
def get_execution_context() -> ExecutionContext:
113+
"""Get the execution context of the current environment.
22114
23115
Returns:
24-
The runtime environment.
116+
The execution context of the current environment.
25117
"""
26-
# First check for Kaggle
27-
if _is_kaggle():
28-
return Runtime(interactive=True, kernel="kaggle")
118+
# First check for environment
119+
environment = _get_environment()
120+
121+
# Next, get kernel information
122+
interactive, kernel = _get_kernel()
29123

30-
# Next check for CI
31-
if _is_ci():
32-
return Runtime(interactive=False, kernel=None, ci=True)
124+
context = ExecutionContext(
125+
interactive=interactive, kernel=kernel, environment=environment, ci=_is_ci()
126+
)
127+
return context
33128

34-
# Check for IPython
35-
if _is_ipy():
36-
return Runtime(interactive=True, kernel="ipython")
37129

38-
# Jupyter kernel
39-
if _is_jupyter_kernel():
40-
return Runtime(interactive=True, kernel="jupyter")
130+
def _get_kernel() -> Tuple[bool, Optional[KernelType]]:
131+
"""Get the kernel the code is running in.
41132
42-
# TTY
43-
if _is_tty():
44-
return Runtime(interactive=True, kernel="tty")
133+
Returns:
134+
A tuple of (whether the kernel is interactive, the kernel type).
135+
"""
136+
mapping: Dict[KernelType, Callable[[], bool]] = {
137+
"ipython": _is_ipy,
138+
"jupyter": _is_jupyter_kernel,
139+
"tty": _is_tty,
140+
}
141+
for kernel, func in mapping.items():
142+
if func():
143+
return True, kernel
144+
return False, None
45145

46-
# Default to non-interactive
47-
return Runtime(interactive=False, kernel=None)
48146

147+
def _get_environment() -> Optional[EnvironmentType]:
148+
"""Get the environment the code is running in.
49149
50-
def _is_kaggle() -> bool:
51-
"""Check if the current environment is running in a Kaggle kernel.
150+
An environment is a higher-level hosted environment or notebook platform.
151+
This is about where the code is running (Kaggle, Colab, AWS, GCP, ...).
52152
53153
Returns:
54-
bool: True if the current environment is running in a Kaggle kernel.
154+
The environment the code is running in.
55155
"""
56-
# Kaggle-specific and preset env vars
57-
kaggle_env_vars = [
58-
"KAGGLE_KERNEL_RUN_TYPE",
59-
"KAGGLE_URL_BASE",
60-
"KAGGLE_KERNEL_INTEGRATIONS",
61-
"KAGGLE_USER_SECRETS_TOKEN",
62-
]
63-
if any(v in os.environ for v in kaggle_env_vars):
64-
return True
156+
for env_type, hints in ENV_TYPE_HINTS.items():
157+
if any(k in os.environ for k in hints):
158+
return env_type
65159

66-
return False
160+
return None
67161

68162

69163
def _is_ipy() -> bool:

src/tabpfn_common_utils/telemetry/core/service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from posthog import Posthog
66
from .config import download_config
77
from .events import BaseTelemetryEvent
8-
from .runtime import get_runtime
8+
from .runtime import get_execution_context
99
from ...utils import singleton
1010
from typing import Any, Dict, Optional
1111

@@ -71,8 +71,8 @@ def telemetry_enabled(cls) -> bool:
7171
bool: True if telemetry is enabled, False otherwise.
7272
"""
7373
# Disable telemetry by default in CI environments, but allow override
74-
runtime = get_runtime()
75-
default_disable = "1" if runtime.ci else "0"
74+
exec_context = get_execution_context()
75+
default_disable = "1" if exec_context.ci else "0"
7676

7777
disable_telemetry = os.getenv(
7878
"TABPFN_DISABLE_TELEMETRY", default_disable

src/tabpfn_common_utils/telemetry/interactive/flows.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .prompts.base import PromptSpec
1212
from .prompts.newsletter import NewsletterPrompt
1313
from .prompts.identity import IdentityPrompt
14-
from ..core.runtime import get_runtime
14+
from ..core.runtime import get_execution_context
1515

1616

1717
def capture_session(enabled: bool = True) -> None:
@@ -142,8 +142,8 @@ def opt_in(enabled: bool = True, delta_days: int = 30, max_prompts: int = 2) ->
142142
return
143143

144144
# Only show prompts in Jupyter/IPython
145-
runtime = get_runtime()
146-
if runtime.kernel not in {"jupyter", "ipython"}:
145+
exec_context = get_execution_context()
146+
if exec_context.kernel not in {"jupyter", "ipython"}:
147147
return
148148

149149
# Check if prompts should be shown

tests/telemetry/core/test_runtime.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
_is_ipy,
1111
_is_jupyter_kernel,
1212
_is_tty,
13-
get_runtime,
13+
get_execution_context,
1414
)
1515

1616

@@ -24,7 +24,7 @@ def setup(self) -> None:
2424
def test_get_runtime_ci_environment(self) -> None:
2525
"""Test that CI environments are detected correctly."""
2626
with patch(f"{self.module}._is_ci", return_value=True):
27-
runtime = get_runtime()
27+
runtime = get_execution_context()
2828
assert runtime.ci is True
2929
assert runtime.interactive is False
3030
assert runtime.kernel is None
@@ -35,7 +35,7 @@ def test_get_runtime_interactive_ipython(self) -> None:
3535
patch(f"{self.module}._is_ci", return_value=False),
3636
patch(f"{self.module}._is_ipy", return_value=True),
3737
):
38-
runtime = get_runtime()
38+
runtime = get_execution_context()
3939
assert runtime.interactive is True
4040
assert runtime.ci is False
4141

@@ -46,7 +46,7 @@ def test_get_runtime_interactive_jupyter(self) -> None:
4646
patch(f"{self.module}._is_ipy", return_value=False),
4747
patch(f"{self.module}._is_jupyter_kernel", return_value=True),
4848
):
49-
runtime = get_runtime()
49+
runtime = get_execution_context()
5050
assert runtime.interactive is True
5151
assert runtime.ci is False
5252

@@ -57,7 +57,7 @@ def test_get_runtime_default_noninteractive(self) -> None:
5757
patch(f"{self.module}._is_ipy", return_value=False),
5858
patch(f"{self.module}._is_jupyter_kernel", return_value=False),
5959
):
60-
runtime = get_runtime()
60+
runtime = get_execution_context()
6161
assert runtime.interactive is False
6262
assert runtime.ci is False
6363

0 commit comments

Comments
 (0)