Skip to content

Commit 08c01aa

Browse files
Add support for providing a default instance to the databricks agent (#3190)
1 parent af199b5 commit 08c01aa

File tree

2 files changed

+64
-11
lines changed

2 files changed

+64
-11
lines changed

plugins/flytekit-spark/flytekitplugins/spark/agent.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import http
22
import json
3+
import os
34
import typing
45
from dataclasses import dataclass
56
from typing import Optional
@@ -17,6 +18,7 @@
1718
aiohttp = lazy_module("aiohttp")
1819

1920
DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"
21+
DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE"
2022

2123

2224
@dataclass
@@ -69,7 +71,15 @@ async def create(
6971
self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs
7072
) -> DatabricksJobMetadata:
7173
data = json.dumps(_get_databricks_job_spec(task_template))
72-
databricks_instance = task_template.custom["databricksInstance"]
74+
databricks_instance = task_template.custom.get(
75+
"databricksInstance", os.getenv(DEFAULT_DATABRICKS_INSTANCE_ENV_KEY)
76+
)
77+
78+
if not databricks_instance:
79+
raise ValueError(
80+
f"Missing databricks instance. Please set the value through the task config or set the {DEFAULT_DATABRICKS_INSTANCE_ENV_KEY} environment variable in the agent."
81+
)
82+
7383
databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit"
7484

7585
async with aiohttp.ClientSession() as session:

plugins/flytekit-spark/tests/test_agent.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,17 @@
99

1010
from flytekit.core.constants import FLYTE_FAIL_ON_ERROR
1111
from flytekitplugins.spark.agent import DATABRICKS_API_ENDPOINT, DatabricksJobMetadata, get_header, \
12-
_get_databricks_job_spec
12+
_get_databricks_job_spec, DEFAULT_DATABRICKS_INSTANCE_ENV_KEY
1313

1414
from flytekit.extend.backend.base_agent import AgentRegistry
1515
from flytekit.interfaces.cli_identifiers import Identifier
1616
from flytekit.models import literals, task
1717
from flytekit.models.core.identifier import ResourceType
1818
from flytekit.models.task import Container, Resources, TaskTemplate
19+
import os
1920

20-
21-
@pytest.mark.asyncio
22-
async def test_databricks_agent():
23-
agent = AgentRegistry.get_agent("spark")
24-
21+
@pytest.fixture(scope="function")
22+
def task_template() -> TaskTemplate:
2523
task_id = Identifier(
2624
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
2725
)
@@ -55,8 +53,7 @@ async def test_databricks_agent():
5553
},
5654
"timeout_seconds": 3600,
5755
"max_retries": 1,
58-
},
59-
"databricksInstance": "test-account.cloud.databricks.com",
56+
}
6057
}
6158
container = Container(
6259
image="flyteorg/flytekit:databricks-0.18.0-py3.7",
@@ -103,6 +100,16 @@ async def test_databricks_agent():
103100
interface=None,
104101
type="spark",
105102
)
103+
104+
return dummy_template
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_databricks_agent(task_template: TaskTemplate):
109+
agent = AgentRegistry.get_agent("spark")
110+
111+
task_template.custom["databricksInstance"] = "test-account.cloud.databricks.com"
112+
106113
mocked_token = "mocked_databricks_token"
107114
mocked_context = mock.patch("flytekit.current_context", autospec=True).start()
108115
mocked_context.return_value.secrets.get.return_value = mocked_token
@@ -124,8 +131,8 @@ async def test_databricks_agent():
124131
delete_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/cancel"
125132
with aioresponses() as mocked:
126133
mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response)
127-
res = await agent.create(dummy_template, None)
128-
spec = _get_databricks_job_spec(dummy_template)
134+
res = await agent.create(task_template, None)
135+
spec = _get_databricks_job_spec(task_template)
129136
data = json.dumps(spec)
130137
mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header())
131138
spark_envs = spec["new_cluster"]["spark_env_vars"]
@@ -147,3 +154,39 @@ async def test_databricks_agent():
147154
assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"}
148155

149156
mock.patch.stopall()
157+
158+
159+
@pytest.mark.asyncio
160+
async def test_agent_create_with_default_instance(task_template: TaskTemplate):
161+
agent = AgentRegistry.get_agent("spark")
162+
163+
mocked_token = "mocked_databricks_token"
164+
mocked_context = mock.patch("flytekit.current_context", autospec=True).start()
165+
mocked_context.return_value.secrets.get.return_value = mocked_token
166+
167+
databricks_metadata = DatabricksJobMetadata(
168+
databricks_instance="test-account.cloud.databricks.com",
169+
run_id="123",
170+
)
171+
172+
mock_create_response = {"run_id": "123"}
173+
174+
os.environ[DEFAULT_DATABRICKS_INSTANCE_ENV_KEY] = "test-account.cloud.databricks.com"
175+
176+
create_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit"
177+
with aioresponses() as mocked:
178+
mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response)
179+
res = await agent.create(task_template, None)
180+
spec = _get_databricks_job_spec(task_template)
181+
data = json.dumps(spec)
182+
mocked.assert_called_with(create_url, method="POST", data=data, headers=get_header())
183+
assert res == databricks_metadata
184+
185+
mock.patch.stopall()
186+
187+
@pytest.mark.asyncio
188+
async def test_agent_create_with_no_instance(task_template: TaskTemplate):
189+
agent = AgentRegistry.get_agent("spark")
190+
191+
with pytest.raises(ValueError) as e:
192+
await agent.create(task_template, None)

0 commit comments

Comments
 (0)