Skip to content

Commit 440fd9d

Browse files
authored
Ele 4724 dbt fusion support (#1961)
* add dbt fusion runner * remove call_macro * fixes * bugfix * bugfix - parent dataclass can't have defaults * replace dbt.create_table_as with our implementation * don't run deps from the runner for dbt fusion
1 parent 44bb802 commit 440fd9d

File tree

7 files changed

+85
-43
lines changed

7 files changed

+85
-43
lines changed

elementary/clients/dbt/api_dbt_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ def collect_dbt_command_logs(event):
5757
logs=[DbtLog.from_log_line(log) for log in dbt_logs],
5858
)
5959

60-
return APIDbtCommandResult(success=res.success, output=output, result_obj=res)
60+
return APIDbtCommandResult(
61+
success=res.success, output=output, stderr=None, result_obj=res
62+
)
6163

6264
def _parse_ls_command_result(
6365
self, select: Optional[str], result: DbtCommandResult

elementary/clients/dbt/command_line_dbt_runner.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
class DbtCommandResult:
2626
success: bool
2727
output: Optional[str]
28+
stderr: Optional[str]
2829

2930

3031
class CommandLineDbtRunner(BaseDbtRunner):
@@ -190,16 +191,23 @@ def run_operation(
190191
log_pattern = (
191192
RAW_EDR_LOGS_PATTERN if return_raw_edr_logs else MACRO_RESULT_PATTERN
192193
)
193-
if capture_output and result.output is not None:
194-
for log in parse_dbt_output(result.output):
195-
if log_errors and log.level == "error":
196-
logger.error(log.msg)
197-
continue
198-
199-
if log.msg:
200-
match = log_pattern.match(log.msg)
201-
if match:
202-
run_operation_results.append(match.group(1))
194+
if capture_output:
195+
if result.output is not None:
196+
for log in parse_dbt_output(result.output):
197+
if log_errors and log.level == "error":
198+
logger.error(log.msg)
199+
continue
200+
201+
if log.msg:
202+
match = log_pattern.match(log.msg)
203+
if match:
204+
run_operation_results.append(match.group(1))
205+
206+
if result.stderr is not None and log_errors:
207+
for log in parse_dbt_output(result.stderr):
208+
if log.level == "error":
209+
logger.error(log.msg)
210+
continue
203211

204212
return run_operation_results
205213

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import os
2+
3+
from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner
4+
5+
DBT_FUSION_PATH = os.getenv("DBT_FUSION_PATH", "~/.local/bin/dbt")
6+
7+
8+
class DbtFusionRunner(SubprocessDbtRunner):
9+
def _get_dbt_command_name(self) -> str:
10+
return os.path.expanduser(DBT_FUSION_PATH)
11+
12+
def _run_deps_if_needed(self):
13+
# Currently we don't support auto-updating deps for dbt fusion
14+
return

elementary/clients/dbt/factory.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
import os
2+
from enum import Enum
23
from typing import Any, Dict, Optional, Type
34

45
from dbt.version import __version__ as dbt_version_string
56
from packaging import version
67

78
from elementary.clients.dbt.command_line_dbt_runner import CommandLineDbtRunner
9+
from elementary.clients.dbt.dbt_fusion_runner import DbtFusionRunner
10+
from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner
811

912
DBT_VERSION = version.Version(dbt_version_string)
1013

11-
RUNNER_CLASS: Type[CommandLineDbtRunner]
12-
if (
13-
DBT_VERSION >= version.Version("1.5.0")
14-
and os.getenv("DBT_RUNNER_METHOD") != "subprocess"
15-
):
16-
from elementary.clients.dbt.api_dbt_runner import APIDbtRunner
1714

18-
RUNNER_CLASS = APIDbtRunner
19-
else:
20-
from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner
21-
22-
RUNNER_CLASS = SubprocessDbtRunner
15+
class RunnerMethod(Enum):
16+
SUBPROCESS = "subprocess"
17+
API = "api"
18+
FUSION = "fusion"
2319

2420

2521
def create_dbt_runner(
@@ -33,8 +29,11 @@ def create_dbt_runner(
3329
allow_macros_without_package_prefix: bool = False,
3430
run_deps_if_needed: bool = True,
3531
force_dbt_deps: bool = False,
32+
runner_method: Optional[RunnerMethod] = None,
3633
) -> CommandLineDbtRunner:
37-
return RUNNER_CLASS(
34+
runner_method = runner_method or get_dbt_runner_method()
35+
runner_class = get_dbt_runner_class(runner_method)
36+
return runner_class(
3837
project_dir=project_dir,
3938
profiles_dir=profiles_dir,
4039
target=target,
@@ -46,3 +45,27 @@ def create_dbt_runner(
4645
run_deps_if_needed=run_deps_if_needed,
4746
force_dbt_deps=force_dbt_deps,
4847
)
48+
49+
50+
def get_dbt_runner_method() -> RunnerMethod:
51+
runner_method = os.getenv("DBT_RUNNER_METHOD")
52+
if runner_method:
53+
return RunnerMethod(runner_method)
54+
55+
if DBT_VERSION >= version.Version("1.5.0"):
56+
return RunnerMethod.API
57+
return RunnerMethod.SUBPROCESS
58+
59+
60+
def get_dbt_runner_class(runner_method: RunnerMethod) -> Type[CommandLineDbtRunner]:
61+
if runner_method == RunnerMethod.API:
62+
# Import it internally since it will fail if the dbt version is below 1.5.0
63+
from elementary.clients.dbt.api_dbt_runner import APIDbtRunner
64+
65+
return APIDbtRunner
66+
elif runner_method == RunnerMethod.SUBPROCESS:
67+
return SubprocessDbtRunner
68+
elif runner_method == RunnerMethod.FUSION:
69+
return DbtFusionRunner
70+
else:
71+
raise ValueError(f"Invalid runner method: {runner_method}")

elementary/clients/dbt/subprocess_dbt_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,17 @@ def _inner_run_command(
2626
) -> DbtCommandResult:
2727
try:
2828
result = subprocess.run(
29-
["dbt"] + dbt_command_args,
29+
[self._get_dbt_command_name()] + dbt_command_args,
3030
check=self.raise_on_failure,
3131
capture_output=capture_output or quiet,
3232
env=self._get_command_env(),
3333
cwd=self.project_dir,
3434
)
3535
success = result.returncode == 0
3636
output = result.stdout.decode() if result.stdout else None
37+
stderr = result.stderr.decode() if result.stderr else None
3738

38-
return DbtCommandResult(success=success, output=output)
39+
return DbtCommandResult(success=success, output=output, stderr=stderr)
3940
except subprocess.CalledProcessError as err:
4041
logs = (
4142
list(parse_dbt_output(err.output.decode(), log_format))
@@ -49,6 +50,9 @@ def _inner_run_command(
4950
base_command_args=dbt_command_args, logs=logs, err=err
5051
)
5152

53+
def _get_dbt_command_name(self) -> str:
54+
return "dbt"
55+
5256
def _parse_ls_command_result(
5357
self, select: Optional[str], result: DbtCommandResult
5458
) -> List[str]:

tests/tests_with_db/dbt_project/macros/create_all_types_table.sql

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@
3131
CURRENT_TIME() as time_col,
3232
CURRENT_TIMESTAMP() as timestamp_col,
3333
{% endset %}
34-
{% set create_table_query = dbt.create_table_as(false, relation, sql_query) %}
35-
{% do elementary.edr_log(create_table_query) %}
36-
{% do elementary.run_query(create_table_query) %}
34+
{% do elementary.edr_create_table_as(false, relation, sql_query) %}
3735
{% endmacro %}
3836

3937
{% macro snowflake__create_all_types_table() %}
@@ -81,9 +79,7 @@
8179
[1,2,3] as array_col,
8280
TO_GEOGRAPHY('POINT(-122.35 37.55)') as geography_col
8381
{% endset %}
84-
{% set create_table_query = dbt.create_table_as(false, relation, sql_query) %}
85-
{% do elementary.edr_log(create_table_query) %}
86-
{% do elementary.run_query(create_table_query) %}
82+
{% do elementary.edr_create_table_as(false, relation, sql_query) %}
8783
{% endmacro %}
8884

8985
{% macro redshift__create_all_types_table() %}
@@ -123,10 +119,7 @@
123119
ST_GeogFromText('SRID=4324;POLYGON((0 0,0 1,1 1,10 10,1 0,0 0))') as geography_col,
124120
JSON_PARSE('{"data_type": "super"}') as super_col
125121
{% endset %}
126-
{% set create_table_query = dbt.create_table_as(false, relation, sql_query) %}
127-
{% do elementary.edr_log(create_table_query) %}
128-
{% do elementary.run_query(create_table_query) %}
129-
122+
{% do elementary.edr_create_table_as(false, relation, sql_query) %}
130123
{% endmacro %}
131124

132125
{% macro postgres__create_all_types_table() %}
@@ -184,9 +177,7 @@
184177
'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'::uuid as uuid_col,
185178
xmlcomment('text') as xml_col
186179
{% endset %}
187-
{% set create_table_query = dbt.create_table_as(false, relation, sql_query) %}
188-
{% do elementary.edr_log(create_table_query) %}
189-
{% do elementary.run_query(create_table_query) %}
180+
{% do elementary.edr_create_table_as(false, relation, sql_query) %}
190181
{% endmacro %}
191182

192183
{% macro default__create_all_types_table() %}
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
{% materialization test, default %}
22
{% if var('enable_elementary_test_materialization', false) %}
3-
{% do return(elementary.materialization_test_default.call_macro()) %}
3+
{% do return(elementary.materialization_test_default()) %}
44
{% else %}
5-
{% do return(dbt.materialization_test_default.call_macro()) %}
5+
{% do return(dbt.materialization_test_default()) %}
66
{% endif %}
77
{% endmaterialization %}
88

99
{% materialization test, adapter="snowflake" %}
1010
{% if var('enable_elementary_test_materialization', false) %}
11-
{% do return(elementary.materialization_test_snowflake.call_macro()) %}
11+
{% do return(elementary.materialization_test_snowflake()) %}
1212
{% else %}
1313
{% if dbt.materialization_test_snowflake %}
14-
{% do return(dbt.materialization_test_snowflake.call_macro()) %}
14+
{% do return(dbt.materialization_test_snowflake()) %}
1515
{% else %}
16-
{% do return(dbt.materialization_test_default.call_macro()) %}
16+
{% do return(dbt.materialization_test_default()) %}
1717
{% endif %}
1818
{% endif %}
1919
{% endmaterialization %}

0 commit comments

Comments
 (0)