Skip to content

Commit f576b89

Browse files
kzidanepietern
andauthored
Add support for dbt tasks (#520)
Co-authored-by: Pieter Noordhuis <[email protected]>
1 parent 9d8e829 commit f576b89

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

databricks_cli/sdk/service.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def create_job(
5757
job_clusters=None,
5858
tags=None,
5959
format=None,
60+
dbt_task=None,
6061
):
6162
_data = {}
6263
if name is not None:
@@ -125,6 +126,10 @@ def create_job(
125126
_data['tags'] = tags
126127
if format is not None:
127128
_data['format'] = format
129+
if dbt_task is not None:
130+
_data['dbt_task'] = dbt_task
131+
if not isinstance(dbt_task, dict):
132+
raise TypeError('Expected databricks.DbtTask() or dict for field dbt_task')
128133
return self.client.perform_query(
129134
'POST', '/jobs/create', data=_data, headers=headers, version=version
130135
)
@@ -146,6 +151,7 @@ def submit_run(
146151
idempotency_token=None,
147152
job_clusters=None,
148153
git_source=None,
154+
dbt_task=None,
149155
):
150156
_data = {}
151157
if run_name is not None:
@@ -194,6 +200,10 @@ def submit_run(
194200
_data['git_source'] = git_source
195201
if not isinstance(git_source, dict):
196202
raise TypeError('Expected databricks.GitSource() or dict for field git_source')
203+
if dbt_task is not None:
204+
_data['dbt_task'] = dbt_task
205+
if not isinstance(dbt_task, dict):
206+
raise TypeError('Expected databricks.DbtTask() or dict for field dbt_task')
197207
return self.client.perform_query(
198208
'POST', '/jobs/runs/submit', data=_data, headers=headers, version=version
199209
)
@@ -253,6 +263,7 @@ def run_now(
253263
idempotency_token=None,
254264
headers=None,
255265
version=None,
266+
dbt_commands=None,
256267
):
257268
_data = {}
258269
if job_id is not None:
@@ -269,6 +280,8 @@ def run_now(
269280
_data['python_named_params'] = python_named_params
270281
if idempotency_token is not None:
271282
_data['idempotency_token'] = idempotency_token
283+
if dbt_commands is not None:
284+
_data['dbt_commands'] = dbt_commands
272285
return self.client.perform_query(
273286
'POST', '/jobs/run-now', data=_data, headers=headers, version=version
274287
)
@@ -285,6 +298,7 @@ def repair(
285298
python_named_params=None,
286299
headers=None,
287300
version=None,
301+
dbt_commands=None,
288302
):
289303
_data = {}
290304
if run_id is not None:
@@ -303,6 +317,8 @@ def repair(
303317
_data['spark_submit_params'] = spark_submit_params
304318
if python_named_params is not None:
305319
_data['python_named_params'] = python_named_params
320+
if dbt_commands is not None:
321+
_data['dbt_commands'] = dbt_commands
306322
return self.client.perform_query(
307323
'POST', '/jobs/runs/repair', data=_data, headers=headers, version=version
308324
)

tests/sdk/test_service.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,36 @@ def test_create_job(jobs_service):
130130
jobs_service.client.perform_query.assert_called_with('POST', '/jobs/create', data={'tasks': tasks}, headers=None, version='2.1')
131131

132132

133+
@provide_conf
134+
def test_create_dbt_task(jobs_service):
135+
git_source = {
136+
'git_provider': 'github',
137+
'git_url': 'https://github.com/foo/bar',
138+
'git_branch': 'main'
139+
}
140+
141+
tasks = [
142+
{
143+
'task_key': 'dbt',
144+
'dbt_task': {
145+
'commands': ['dbt test']
146+
}
147+
}
148+
]
149+
150+
jobs_service.create_job(git_source=git_source, tasks=tasks)
151+
jobs_service.client.perform_query.assert_called_with('POST', '/jobs/create', data={'git_source': git_source, 'tasks': tasks}, headers=None, version=None)
152+
153+
154+
@provide_conf
155+
def test_run_now_dbt_task(jobs_service):
156+
job_id = 1337
157+
dbt_commands = ['dbt test', 'dbt deps']
158+
159+
jobs_service.run_now(job_id=job_id, dbt_commands=dbt_commands)
160+
jobs_service.client.perform_query.assert_called_with('POST', '/jobs/run-now', data={'job_id': job_id, 'dbt_commands': dbt_commands}, headers=None, version=None)
161+
162+
133163
@provide_conf
134164
def test_create_job_invalid_types(jobs_service):
135165
with pytest.raises(TypeError, match='new_cluster'):
@@ -140,7 +170,10 @@ def test_create_job_invalid_types(jobs_service):
140170

141171
with pytest.raises(TypeError, match='schedule'):
142172
jobs_service.create_job(schedule=[])
143-
173+
174+
with pytest.raises(TypeError, match='git_source'):
175+
jobs_service.create_job(git_source=[])
176+
144177
with pytest.raises(TypeError, match='notebook_task'):
145178
jobs_service.create_job(notebook_task=[])
146179

@@ -153,6 +186,9 @@ def test_create_job_invalid_types(jobs_service):
153186
with pytest.raises(TypeError, match='spark_submit_task'):
154187
jobs_service.create_job(spark_submit_task=[])
155188

189+
with pytest.raises(TypeError, match='dbt_task'):
190+
jobs_service.create_job(dbt_task=[])
191+
156192

157193
@provide_conf
158194
def test_submit_run_invalid_types(jobs_service):
@@ -164,7 +200,10 @@ def test_submit_run_invalid_types(jobs_service):
164200

165201
with pytest.raises(TypeError, match='schedule'):
166202
jobs_service.submit_run(schedule=[])
167-
203+
204+
with pytest.raises(TypeError, match='git_source'):
205+
jobs_service.submit_run(git_source=[])
206+
168207
with pytest.raises(TypeError, match='notebook_task'):
169208
jobs_service.submit_run(notebook_task=[])
170209

@@ -176,3 +215,6 @@ def test_submit_run_invalid_types(jobs_service):
176215

177216
with pytest.raises(TypeError, match='spark_submit_task'):
178217
jobs_service.submit_run(spark_submit_task=[])
218+
219+
with pytest.raises(TypeError, match='dbt_task'):
220+
jobs_service.submit_run(dbt_task=[])

0 commit comments

Comments
 (0)