99
1010from flytekit .core .constants import FLYTE_FAIL_ON_ERROR
1111from flytekitplugins .spark .agent import DATABRICKS_API_ENDPOINT , DatabricksJobMetadata , get_header , \
12- _get_databricks_job_spec
12+ _get_databricks_job_spec , DEFAULT_DATABRICKS_INSTANCE_ENV_KEY
1313
1414from flytekit .extend .backend .base_agent import AgentRegistry
1515from flytekit .interfaces .cli_identifiers import Identifier
1616from flytekit .models import literals , task
1717from flytekit .models .core .identifier import ResourceType
1818from flytekit .models .task import Container , Resources , TaskTemplate
19+ import os
1920
20-
21- @pytest .mark .asyncio
22- async def test_databricks_agent ():
23- agent = AgentRegistry .get_agent ("spark" )
24-
21+ @pytest .fixture (scope = "function" )
22+ def task_template () -> TaskTemplate :
2523 task_id = Identifier (
2624 resource_type = ResourceType .TASK , project = "project" , domain = "domain" , name = "name" , version = "version"
2725 )
@@ -55,8 +53,7 @@ async def test_databricks_agent():
5553 },
5654 "timeout_seconds" : 3600 ,
5755 "max_retries" : 1 ,
58- },
59- "databricksInstance" : "test-account.cloud.databricks.com" ,
56+ }
6057 }
6158 container = Container (
6259 image = "flyteorg/flytekit:databricks-0.18.0-py3.7" ,
@@ -103,6 +100,16 @@ async def test_databricks_agent():
103100 interface = None ,
104101 type = "spark" ,
105102 )
103+
104+ return dummy_template
105+
106+
107+ @pytest .mark .asyncio
108+ async def test_databricks_agent (task_template : TaskTemplate ):
109+ agent = AgentRegistry .get_agent ("spark" )
110+
111+ task_template .custom ["databricksInstance" ] = "test-account.cloud.databricks.com"
112+
106113 mocked_token = "mocked_databricks_token"
107114 mocked_context = mock .patch ("flytekit.current_context" , autospec = True ).start ()
108115 mocked_context .return_value .secrets .get .return_value = mocked_token
@@ -124,8 +131,8 @@ async def test_databricks_agent():
124131 delete_url = f"https://test-account.cloud.databricks.com{ DATABRICKS_API_ENDPOINT } /runs/cancel"
125132 with aioresponses () as mocked :
126133 mocked .post (create_url , status = http .HTTPStatus .OK , payload = mock_create_response )
127- res = await agent .create (dummy_template , None )
128- spec = _get_databricks_job_spec (dummy_template )
134+ res = await agent .create (task_template , None )
135+ spec = _get_databricks_job_spec (task_template )
129136 data = json .dumps (spec )
130137 mocked .assert_called_with (create_url , method = "POST" , data = data , headers = get_header ())
131138 spark_envs = spec ["new_cluster" ]["spark_env_vars" ]
@@ -147,3 +154,39 @@ async def test_databricks_agent():
147154 assert get_header () == {"Authorization" : f"Bearer { mocked_token } " , "content-type" : "application/json" }
148155
149156 mock .patch .stopall ()
157+
158+
159+ @pytest .mark .asyncio
160+ async def test_agent_create_with_default_instance (task_template : TaskTemplate ):
161+ agent = AgentRegistry .get_agent ("spark" )
162+
163+ mocked_token = "mocked_databricks_token"
164+ mocked_context = mock .patch ("flytekit.current_context" , autospec = True ).start ()
165+ mocked_context .return_value .secrets .get .return_value = mocked_token
166+
167+ databricks_metadata = DatabricksJobMetadata (
168+ databricks_instance = "test-account.cloud.databricks.com" ,
169+ run_id = "123" ,
170+ )
171+
172+ mock_create_response = {"run_id" : "123" }
173+
174+ os .environ [DEFAULT_DATABRICKS_INSTANCE_ENV_KEY ] = "test-account.cloud.databricks.com"
175+
176+ create_url = f"https://test-account.cloud.databricks.com{ DATABRICKS_API_ENDPOINT } /runs/submit"
177+ with aioresponses () as mocked :
178+ mocked .post (create_url , status = http .HTTPStatus .OK , payload = mock_create_response )
179+ res = await agent .create (task_template , None )
180+ spec = _get_databricks_job_spec (task_template )
181+ data = json .dumps (spec )
182+ mocked .assert_called_with (create_url , method = "POST" , data = data , headers = get_header ())
183+ assert res == databricks_metadata
184+
185+ mock .patch .stopall ()
186+
187+ @pytest .mark .asyncio
188+ async def test_agent_create_with_no_instance (task_template : TaskTemplate ):
189+ agent = AgentRegistry .get_agent ("spark" )
190+
191+ with pytest .raises (ValueError ) as e :
192+ await agent .create (task_template , None )
0 commit comments