Skip to content

Commit 4f347de

Browse files
authored
fix: Python models over command execution api timeout issue (#1243)
### Description

The Command Execution API call was timing out automatically after 20 minutes because of the default timeout applied when waiting on the result. My intent was to grab the command ID as soon as it was available and then poll, so that we would have the capability to cancel the command if the user killed dbt. This change switches back to that approach.

### Checklist

- [x] I have run this code in development and it appears to resolve the stated issue
- [x] This PR includes tests, or tests are not required/relevant for this PR
- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-databricks next" section.
1 parent 8795478 commit 4f347de

File tree

6 files changed

+193
-46
lines changed

6 files changed

+193
-46
lines changed

dbt/adapters/databricks/api_client.py

Lines changed: 98 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import random
23
import re
34
import time
45
from abc import ABC, abstractmethod
@@ -140,7 +141,10 @@ def _handle_start_cluster_error(self, cluster_id: str, error: Exception) -> None
140141

141142
class CommandContextApi:
142143
def __init__(
143-
self, workspace_client: WorkspaceClient, cluster_api: ClusterApi, library_api: LibraryApi
144+
self,
145+
workspace_client: WorkspaceClient,
146+
cluster_api: ClusterApi,
147+
library_api: LibraryApi,
144148
):
145149
self.workspace_client = workspace_client
146150
self.cluster_api = cluster_api
@@ -162,21 +166,92 @@ def _ensure_cluster_ready(self, cluster_id: str) -> None:
162166
):
163167
self.cluster_api.wait_for_cluster(cluster_id)
164168

165-
def _create_execution_context(self, cluster_id: str) -> str:
166-
try:
167-
result = self.workspace_client.command_execution.create(
168-
cluster_id=cluster_id,
169-
language=ComputeLanguage.PYTHON,
170-
)
169+
def _create_execution_context(self, cluster_id: str, max_retries: int = 5) -> str:
170+
"""Create execution context with retry logic for transient failures.
171171
172-
context_response = result.result()
173-
context_id = context_response.id
174-
if context_id is None:
175-
raise DbtRuntimeError("Failed to create execution context: no context ID returned")
176-
logger.info(f"Created execution context with id={context_id}")
177-
return context_id
178-
except Exception as e:
179-
raise DbtRuntimeError(f"Error creating an execution context.\n {e}")
172+
Args:
173+
cluster_id: The cluster ID to create the context on
174+
max_retries: Maximum number of retry attempts (default: 5)
175+
176+
Returns:
177+
The execution context ID
178+
179+
Raises:
180+
DbtRuntimeError: If context creation fails after all retries
181+
"""
182+
last_error = None
183+
for attempt in range(max_retries):
184+
context_id = None
185+
try:
186+
# Use SDK to create execution context - returns a Wait object
187+
# The Wait object provides context_id immediately, but we need to call result()
188+
# to wait for the context to reach RUNNING state
189+
waiter = self.workspace_client.command_execution.create(
190+
cluster_id=cluster_id,
191+
language=ComputeLanguage.PYTHON,
192+
)
193+
194+
# Get context_id immediately (available before waiting)
195+
context_id = waiter.context_id
196+
if context_id is None:
197+
raise DbtRuntimeError(
198+
"Failed to create execution context: no context ID returned"
199+
)
200+
201+
logger.debug(f"Execution context {context_id} created, waiting for RUNNING state")
202+
203+
# Now wait for the context to reach RUNNING state
204+
# This is where it may fail with ContextStatus.ERROR
205+
waiter.result()
206+
207+
logger.info(f"Execution context {context_id} reached RUNNING state")
208+
return context_id
209+
except Exception as e:
210+
last_error = e
211+
error_msg = str(e).lower()
212+
213+
# Log full exception details for debugging
214+
logger.debug(
215+
f"Execution context {context_id or 'unknown'} creation exception: "
216+
f"type={type(e).__name__}, message={e}"
217+
)
218+
219+
# Retry on transient errors (resource contention, temporary failures)
220+
# ContextStatus.ERROR can occur when cluster is under heavy load
221+
if "contextstatus.error" in error_msg or "failed to reach running" in error_msg:
222+
if attempt < max_retries - 1:
223+
# If we have a context_id, try to destroy it before retrying
224+
if context_id:
225+
try:
226+
logger.debug(f"Destroying failed context {context_id}")
227+
self.workspace_client.command_execution.destroy(
228+
cluster_id=cluster_id, context_id=context_id
229+
)
230+
except Exception as cleanup_error:
231+
logger.debug(
232+
f"Failed to destroy context {context_id}: {cleanup_error}"
233+
)
234+
235+
# Exponential backoff with jitter: base 2^attempt + random 0-1s
236+
# This helps prevent thundering herd when many contexts retry at once
237+
base_wait = 2**attempt # 1s, 2s, 4s, 8s with the default max_retries=5; the final attempt raises instead of sleeping
238+
jitter = random.random() # 0-1 second
239+
wait_time = base_wait + jitter
240+
logger.warning(
241+
f"Execution context creation failed "
242+
f"(attempt {attempt + 1}/{max_retries}), "
243+
f"retrying in {wait_time:.1f}s: {e}"
244+
)
245+
time.sleep(wait_time)
246+
continue
247+
248+
# Non-retryable error or final attempt - raise immediately
249+
raise DbtRuntimeError(f"Error creating an execution context.\n {e}")
250+
251+
# Reached only if the loop body never ran (max_retries <= 0); a failed final attempt is raised inside the loop above
252+
raise DbtRuntimeError(
253+
f"Error creating an execution context after {max_retries} attempts.\n {last_error}"
254+
)
180255

181256
def destroy(self, cluster_id: str, context_id: str) -> None:
182257
try:
@@ -300,15 +375,20 @@ def __init__(self, workspace_client: WorkspaceClient, polling_interval: int, tim
300375

301376
def execute(self, cluster_id: str, context_id: str, command: str) -> CommandExecution:
302377
try:
303-
# Use SDK to execute command
304-
result = self.workspace_client.command_execution.execute(
378+
# Use SDK to execute command - returns a Wait object immediately
379+
# The command_id is available via __getattr__ without calling result()
380+
# We don't call result() because that would wait for execution to finish,
381+
# and we want to use our own timeout via poll_for_completion()
382+
waiter = self.workspace_client.command_execution.execute(
305383
cluster_id=cluster_id,
306384
context_id=context_id,
307385
language=ComputeLanguage.PYTHON, # SUBMISSION_LANGUAGE was "python"
308386
command=command,
309387
)
310388

311-
command_id = result.result().id
389+
# Extract command_id from the waiter without blocking
390+
# The SDK provides this immediately in the kwargs
391+
command_id = waiter.command_id
312392
if command_id is None:
313393
raise DbtRuntimeError("Failed to execute command: no command ID returned")
314394
logger.debug(f"Command executed with id={command_id}")

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,10 @@ python = "3.10"
8888
[tool.hatch.envs.default.scripts]
8989
setup-precommit = "pre-commit install"
9090
code-quality = "pre-commit run --all-files"
91-
unit = "pytest --color=yes -v --profile databricks_cluster -n auto tests/unit"
92-
cluster-e2e = "pytest --color=yes -v --profile databricks_cluster -n auto --dist=loadfile tests/functional"
93-
uc-cluster-e2e = "pytest --color=yes -v --profile databricks_uc_cluster -n auto --dist=loadfile tests/functional"
94-
sqlw-e2e = "pytest --color=yes -v --profile databricks_uc_sql_endpoint -n auto --dist=loadfile tests/functional"
91+
unit = "pytest --color=yes -v --profile databricks_cluster -n 10 tests/unit"
92+
cluster-e2e = "pytest --color=yes -v --profile databricks_cluster -n 10 --dist=loadfile tests/functional"
93+
uc-cluster-e2e = "pytest --color=yes -v --profile databricks_uc_cluster -n 10 --dist=loadfile tests/functional"
94+
sqlw-e2e = "pytest --color=yes -v --profile databricks_uc_sql_endpoint -n 10 --dist=loadfile tests/functional"
9595

9696
[tool.hatch.envs.test.scripts]
9797
unit = "pytest --color=yes -v --profile databricks_cluster -n 10 --dist=loadscope tests/unit"

tests/functional/adapter/python_model/fixtures.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,26 @@ def model(dbt, spark):
314314
data = [[1,2]] * 10
315315
return spark.createDataFrame(data, schema=['test', 'test2'])
316316
"""
317+
318+
all_purpose_command_api_schema = """version: 2
319+
320+
models:
321+
- name: my_versioned_sql_model
322+
versions:
323+
- v: 1
324+
- name: my_python_model
325+
# No submission_method or create_notebook config here
326+
# Will use project-level config (all_purpose_cluster with create_notebook=False)
327+
328+
sources:
329+
- name: test_source
330+
loader: custom
331+
schema: "{{ var(env_var('DBT_TEST_SCHEMA_NAME_VARIABLE')) }}"
332+
quoting:
333+
identifier: True
334+
tags:
335+
- my_test_source_tag
336+
tables:
337+
- name: test_table
338+
identifier: source
339+
"""

tests/functional/adapter/python_model/test_python_model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,3 +458,32 @@ def test_changing_unique_tmp_table_suffix(self, project):
458458
)
459459
util.run_dbt(["run"])
460460
verify_temp_tables_cleaned(project)
461+
462+
463+
@pytest.mark.python
464+
@pytest.mark.skip_profile("databricks_uc_sql_endpoint")
465+
class TestAllPurposeClusterCommandAPI(BasePythonModelTests):
466+
"""Test Python models using all_purpose_cluster with Command API (create_notebook=False).
467+
468+
This tests the command execution path that uses the Command API directly
469+
without creating notebooks, which exercises the timeout fix for command execution.
470+
"""
471+
472+
@pytest.fixture(scope="class")
473+
def models(self):
474+
return {
475+
"schema.yml": override_fixtures.all_purpose_command_api_schema,
476+
"my_sql_model.sql": fixtures.basic_sql,
477+
"my_versioned_sql_model_v1.sql": fixtures.basic_sql,
478+
"my_python_model.py": fixtures.basic_python,
479+
"second_sql_model.sql": fixtures.second_sql,
480+
}
481+
482+
@pytest.fixture(scope="class")
483+
def project_config_update(self):
484+
return {
485+
"models": {
486+
"+submission_method": "all_purpose_cluster",
487+
"+create_notebook": False, # Use Command API, not notebook submission
488+
}
489+
}

tests/unit/api_client/test_command_api.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ def test_execute__exception(self, api, workspace_client):
3333
assert "Error creating a command" in str(exc_info.value)
3434

3535
def test_execute__success(self, api, workspace_client, execution):
36-
mock_result = Mock()
37-
mock_result.result.return_value.id = "command_id"
38-
workspace_client.command_execution.execute.return_value = mock_result
36+
# Mock the Wait object returned by execute()
37+
# The command_id is available immediately via __getattr__, not via result()
38+
mock_waiter = Mock()
39+
mock_waiter.command_id = "command_id"
40+
workspace_client.command_execution.execute.return_value = mock_waiter
3941

4042
result = api.execute("cluster_id", "context_id", "command")
4143

@@ -46,6 +48,8 @@ def test_execute__success(self, api, workspace_client, execution):
4648
command="command",
4749
language=ComputeLanguage.PYTHON,
4850
)
51+
# result() should NOT be called - we access command_id directly
52+
mock_waiter.result.assert_not_called()
4953

5054
def test_cancel__exception(self, api, workspace_client):
5155
workspace_client.command_execution.cancel.side_effect = Exception("API Error")

tests/unit/api_client/test_command_context_api.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from unittest.mock import Mock
22

33
import pytest
4-
from databricks.sdk.service.compute import ContextStatusResponse
54
from dbt_common.exceptions import DbtRuntimeError
65

76
from dbt.adapters.databricks.api_client import CommandContextApi
@@ -38,11 +37,14 @@ def test_create__cluster_running(self, api, cluster_api, library_api, workspace_
3837
cluster_api.status.return_value = "RUNNING"
3938
library_api.all_libraries_installed.return_value = True
4039

41-
mock_result = Mock()
42-
mock_context_response = Mock(spec=ContextStatusResponse)
43-
mock_context_response.id = "context_id"
44-
mock_result.result.return_value = mock_context_response
45-
workspace_client.command_execution.create.return_value = mock_result
40+
# Mock the Wait object returned by create()
41+
# The Wait object has context_id immediately available and result() waits for RUNNING
42+
mock_waiter = Mock()
43+
mock_waiter.context_id = "context_id"
44+
mock_response = Mock()
45+
mock_response.id = "context_id"
46+
mock_waiter.result.return_value = mock_response
47+
workspace_client.command_execution.create.return_value = mock_waiter
4648

4749
context_id = api.create("cluster_id")
4850

@@ -56,11 +58,14 @@ def test_create__cluster_running_with_pending_libraries(
5658
cluster_api.status.return_value = "RUNNING"
5759
library_api.all_libraries_installed.return_value = False
5860

59-
mock_result = Mock()
60-
mock_context_response = Mock(spec=ContextStatusResponse)
61-
mock_context_response.id = "context_id"
62-
mock_result.result.return_value = mock_context_response
63-
workspace_client.command_execution.create.return_value = mock_result
61+
# Mock the Wait object returned by create()
62+
# The Wait object has context_id immediately available and result() waits for RUNNING
63+
mock_waiter = Mock()
64+
mock_waiter.context_id = "context_id"
65+
mock_response = Mock()
66+
mock_response.id = "context_id"
67+
mock_waiter.result.return_value = mock_response
68+
workspace_client.command_execution.create.return_value = mock_waiter
6469

6570
context_id = api.create("cluster_id")
6671

@@ -71,11 +76,14 @@ def test_create__cluster_running_with_pending_libraries(
7176
def test_create__cluster_terminated(self, api, cluster_api, workspace_client):
7277
cluster_api.status.return_value = "TERMINATED"
7378

74-
mock_result = Mock()
75-
mock_context_response = Mock(spec=ContextStatusResponse)
76-
mock_context_response.id = "context_id"
77-
mock_result.result.return_value = mock_context_response
78-
workspace_client.command_execution.create.return_value = mock_result
79+
# Mock the Wait object returned by create()
80+
# The Wait object has context_id immediately available and result() waits for RUNNING
81+
mock_waiter = Mock()
82+
mock_waiter.context_id = "context_id"
83+
mock_response = Mock()
84+
mock_response.id = "context_id"
85+
mock_waiter.result.return_value = mock_response
86+
workspace_client.command_execution.create.return_value = mock_waiter
7987

8088
api.create("cluster_id")
8189

@@ -84,11 +92,14 @@ def test_create__cluster_terminated(self, api, cluster_api, workspace_client):
8492
def test_create__cluster_pending(self, api, cluster_api, workspace_client):
8593
cluster_api.status.return_value = "PENDING"
8694

87-
mock_result = Mock()
88-
mock_context_response = Mock(spec=ContextStatusResponse)
89-
mock_context_response.id = "context_id"
90-
mock_result.result.return_value = mock_context_response
91-
workspace_client.command_execution.create.return_value = mock_result
95+
# Mock the Wait object returned by create()
96+
# The Wait object has context_id immediately available and result() waits for RUNNING
97+
mock_waiter = Mock()
98+
mock_waiter.context_id = "context_id"
99+
mock_response = Mock()
100+
mock_response.id = "context_id"
101+
mock_waiter.result.return_value = mock_response
102+
workspace_client.command_execution.create.return_value = mock_waiter
92103

93104
api.create("cluster_id")
94105

0 commit comments

Comments
 (0)