|
5 | 5 | import os |
6 | 6 | import sys |
7 | 7 | from dataclasses import dataclass |
8 | | -from typing import Literal |
| 8 | +from typing import Callable, Dict, Literal, Mapping, Optional, Sequence, Tuple |
| 9 | + |
| 10 | + |
| 11 | +# The type of kernel the code is running in |
| 12 | +KernelType = Literal["ipython", "jupyter", "tty"] |
| 13 | + |
| 14 | +# The type of environment the code is running in |
| 15 | +EnvironmentType = Literal[ |
| 16 | + "kaggle", |
| 17 | + "colab", |
| 18 | + "gcp", |
| 19 | + "aws", |
| 20 | + "azure", |
| 21 | + "databricks", |
| 22 | +] |
| 23 | + |
| 24 | +# Static list of environment hints, purely heuristic and based on env variables. |
| 25 | +# This information is used purely for detection purposes, and the values are |
| 26 | +# not propagated or stored in the telemetry state. |
| 27 | +ENV_TYPE_HINTS: Mapping[EnvironmentType, Sequence[str]] = { |
| 28 | + # Notebook providers |
| 29 | + "kaggle": [ |
| 30 | + # Kaggle kernels |
| 31 | + "KAGGLE_KERNEL_RUN_TYPE", |
| 32 | + "KAGGLE_URL_BASE", |
| 33 | + "KAGGLE_KERNEL_INTEGRATIONS", |
| 34 | + "KAGGLE_USER_SECRETS_TOKEN", |
| 35 | + "KAGGLE_GCP_PROJECT", |
| 36 | + "KAGGLE_GCP_ZONE", |
| 37 | + ], |
| 38 | + "colab": [ |
| 39 | + # Google Colab |
| 40 | + "COLAB_GPU", |
| 41 | + "COLAB_TPU_ADDR", |
| 42 | + "COLAB_JUPYTER_TRANSPORT", |
| 43 | + "COLAB_BACKEND_VERSION", |
| 44 | + ], |
| 45 | + "databricks": [ |
| 46 | + # Databricks clusters and runtime |
| 47 | + "DATABRICKS_RUNTIME_VERSION", |
| 48 | + "DATABRICKS_CLUSTER_ID", |
| 49 | + "DATABRICKS_HOST", |
| 50 | + "DATABRICKS_WORKSPACE_URL", |
| 51 | + "DB_IS_DRIVER", |
| 52 | + ], |
| 53 | + # Cloud providers |
| 54 | + "aws": [ |
| 55 | + # Generic AWS environment hints |
| 56 | + "AWS_EXECUTION_ENV", |
| 57 | + "AWS_REGION", |
| 58 | + "AWS_DEFAULT_REGION", |
| 59 | + # SageMaker |
| 60 | + "SM_MODEL_DIR", |
| 61 | + "SM_NUM_GPUS", |
| 62 | + "SM_HOSTS", |
| 63 | + "SM_CURRENT_HOST", |
| 64 | + "TRAINING_JOB_NAME", |
| 65 | + ], |
| 66 | + "gcp": [ |
| 67 | + # Project hints |
| 68 | + "GOOGLE_CLOUD_PROJECT", |
| 69 | + "GCP_PROJECT", |
| 70 | + "GCLOUD_PROJECT", |
| 71 | + "CLOUDSDK_CORE_PROJECT", |
| 72 | + # Cloud Run and Cloud Functions |
| 73 | + "K_SERVICE", |
| 74 | + "K_REVISION", |
| 75 | + "K_CONFIGURATION", |
| 76 | + "CLOUD_RUN_JOB", |
| 77 | + # Vertex AI |
| 78 | + "AIP_MODEL_DIR", |
| 79 | + "AIP_DATA_FORMAT", |
| 80 | + "AIP_TRAINING_DATA_URI", |
| 81 | + "CLOUD_ML_JOB_ID", |
| 82 | + # Cloud Shell |
| 83 | + "CLOUD_SHELL", |
| 84 | + ], |
| 85 | + "azure": [ |
| 86 | + # Azure ML |
| 87 | + "AZUREML_RUN_ID", |
| 88 | + "AZUREML_ARM_SUBSCRIPTION", |
| 89 | + "AZUREML_ARM_RESOURCEGROUP", |
| 90 | + "AZUREML_ARM_WORKSPACE_NAME", |
| 91 | + ], |
| 92 | +} |
9 | 93 |
|
10 | 94 |
|
11 | 95 | @dataclass |
12 | | -class Runtime: |
13 | | - """Runtime environment.""" |
| 96 | +class ExecutionContext: |
| 97 | + """The execution context of the current environment.""" |
14 | 98 |
|
15 | 99 | interactive: bool |
16 | | - kernel: Literal["ipython", "jupyter", "tty", "kaggle"] | None = None |
| 100 | + """Whether the code is running in an interactive environment.""" |
| 101 | + |
| 102 | + kernel: Optional[KernelType] = None |
| 103 | + """Low-level Python frontend or shell (e.g. IPython, Jupyter, TTY).""" |
| 104 | + |
| 105 | + environment: Optional[EnvironmentType] = None |
| 106 | + """Higher-level hosted environment or notebook platform.""" |
| 107 | + |
17 | 108 | ci: bool = False |
| 109 | + """Whether the code is running in a CI environment.""" |
18 | 110 |
|
19 | 111 |
|
20 | | -def get_runtime() -> Runtime: |
21 | | - """Get the runtime environment. |
| 112 | +def get_execution_context() -> ExecutionContext: |
| 113 | + """Get the execution context of the current environment. |
22 | 114 |
|
23 | 115 | Returns: |
24 | | - The runtime environment. |
| 116 | + The execution context of the current environment. |
25 | 117 | """ |
26 | | - # First check for Kaggle |
27 | | - if _is_kaggle(): |
28 | | - return Runtime(interactive=True, kernel="kaggle") |
| 118 | + # First check for environment |
| 119 | + environment = _get_environment() |
| 120 | + |
| 121 | + # Next, get kernel information |
| 122 | + interactive, kernel = _get_kernel() |
29 | 123 |
|
30 | | - # Next check for CI |
31 | | - if _is_ci(): |
32 | | - return Runtime(interactive=False, kernel=None, ci=True) |
| 124 | + context = ExecutionContext( |
| 125 | + interactive=interactive, kernel=kernel, environment=environment, ci=_is_ci() |
| 126 | + ) |
| 127 | + return context |
33 | 128 |
|
34 | | - # Check for IPython |
35 | | - if _is_ipy(): |
36 | | - return Runtime(interactive=True, kernel="ipython") |
37 | 129 |
|
38 | | - # Jupyter kernel |
39 | | - if _is_jupyter_kernel(): |
40 | | - return Runtime(interactive=True, kernel="jupyter") |
| 130 | +def _get_kernel() -> Tuple[bool, Optional[KernelType]]: |
| 131 | + """Get the kernel the code is running in. |
41 | 132 |
|
42 | | - # TTY |
43 | | - if _is_tty(): |
44 | | - return Runtime(interactive=True, kernel="tty") |
| 133 | + Returns: |
| 134 | + A tuple of (whether the kernel is interactive, the kernel type). |
| 135 | + """ |
| 136 | + mapping: Dict[KernelType, Callable[[], bool]] = { |
| 137 | + "ipython": _is_ipy, |
| 138 | + "jupyter": _is_jupyter_kernel, |
| 139 | + "tty": _is_tty, |
| 140 | + } |
| 141 | + for kernel, func in mapping.items(): |
| 142 | + if func(): |
| 143 | + return True, kernel |
| 144 | + return False, None |
45 | 145 |
|
46 | | - # Default to non-interactive |
47 | | - return Runtime(interactive=False, kernel=None) |
48 | 146 |
|
| 147 | +def _get_environment() -> Optional[EnvironmentType]: |
| 148 | + """Get the environment the code is running in. |
49 | 149 |
|
50 | | -def _is_kaggle() -> bool: |
51 | | - """Check if the current environment is running in a Kaggle kernel. |
| 150 | + An environment is a higher-level hosted environment or notebook platform. |
| 151 | + This is about where the code is running (Kaggle, Colab, AWS, GCP, ...). |
52 | 152 |
|
53 | 153 | Returns: |
54 | | - bool: True if the current environment is running in a Kaggle kernel. |
| 154 | + The environment the code is running in. |
55 | 155 | """ |
56 | | - # Kaggle-specific and preset env vars |
57 | | - kaggle_env_vars = [ |
58 | | - "KAGGLE_KERNEL_RUN_TYPE", |
59 | | - "KAGGLE_URL_BASE", |
60 | | - "KAGGLE_KERNEL_INTEGRATIONS", |
61 | | - "KAGGLE_USER_SECRETS_TOKEN", |
62 | | - ] |
63 | | - if any(v in os.environ for v in kaggle_env_vars): |
64 | | - return True |
| 156 | + for env_type, hints in ENV_TYPE_HINTS.items(): |
| 157 | + if any(k in os.environ for k in hints): |
| 158 | + return env_type |
65 | 159 |
|
66 | | - return False |
| 160 | + return None |
67 | 161 |
|
68 | 162 |
|
69 | 163 | def _is_ipy() -> bool: |
|
0 commit comments