diff --git a/elementary/clients/dbt/api_dbt_runner.py b/elementary/clients/dbt/api_dbt_runner.py index c31290227..886b8caa9 100644 --- a/elementary/clients/dbt/api_dbt_runner.py +++ b/elementary/clients/dbt/api_dbt_runner.py @@ -57,7 +57,9 @@ def collect_dbt_command_logs(event): logs=[DbtLog.from_log_line(log) for log in dbt_logs], ) - return APIDbtCommandResult(success=res.success, output=output, result_obj=res) + return APIDbtCommandResult( + success=res.success, output=output, stderr=None, result_obj=res + ) def _parse_ls_command_result( self, select: Optional[str], result: DbtCommandResult diff --git a/elementary/clients/dbt/command_line_dbt_runner.py b/elementary/clients/dbt/command_line_dbt_runner.py index 668ba190b..a7261c73a 100644 --- a/elementary/clients/dbt/command_line_dbt_runner.py +++ b/elementary/clients/dbt/command_line_dbt_runner.py @@ -25,6 +25,7 @@ class DbtCommandResult: success: bool output: Optional[str] + stderr: Optional[str] class CommandLineDbtRunner(BaseDbtRunner): @@ -190,16 +191,23 @@ def run_operation( log_pattern = ( RAW_EDR_LOGS_PATTERN if return_raw_edr_logs else MACRO_RESULT_PATTERN ) - if capture_output and result.output is not None: - for log in parse_dbt_output(result.output): - if log_errors and log.level == "error": - logger.error(log.msg) - continue - - if log.msg: - match = log_pattern.match(log.msg) - if match: - run_operation_results.append(match.group(1)) + if capture_output: + if result.output is not None: + for log in parse_dbt_output(result.output): + if log_errors and log.level == "error": + logger.error(log.msg) + continue + + if log.msg: + match = log_pattern.match(log.msg) + if match: + run_operation_results.append(match.group(1)) + + if result.stderr is not None and log_errors: + for log in parse_dbt_output(result.stderr): + if log.level == "error": + logger.error(log.msg) + continue return run_operation_results diff --git a/elementary/clients/dbt/dbt_fusion_runner.py b/elementary/clients/dbt/dbt_fusion_runner.py new file mode 100644 index 000000000..408054cc1 --- /dev/null +++ b/elementary/clients/dbt/dbt_fusion_runner.py @@ -0,0 +1,14 @@ +import os + +from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner + +DBT_FUSION_PATH = os.getenv("DBT_FUSION_PATH", "~/.local/bin/dbt") + + +class DbtFusionRunner(SubprocessDbtRunner): + def _get_dbt_command_name(self) -> str: + return os.path.expanduser(DBT_FUSION_PATH) + + def _run_deps_if_needed(self): + # Currently we don't support auto-updating deps for dbt fusion + return diff --git a/elementary/clients/dbt/factory.py b/elementary/clients/dbt/factory.py index 85c0b1dca..e03bc9bb0 100644 --- a/elementary/clients/dbt/factory.py +++ b/elementary/clients/dbt/factory.py @@ -1,25 +1,21 @@ import os +from enum import Enum from typing import Any, Dict, Optional, Type from dbt.version import __version__ as dbt_version_string from packaging import version from elementary.clients.dbt.command_line_dbt_runner import CommandLineDbtRunner +from elementary.clients.dbt.dbt_fusion_runner import DbtFusionRunner +from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner DBT_VERSION = version.Version(dbt_version_string) -RUNNER_CLASS: Type[CommandLineDbtRunner] -if ( - DBT_VERSION >= version.Version("1.5.0") - and os.getenv("DBT_RUNNER_METHOD") != "subprocess" -): - from elementary.clients.dbt.api_dbt_runner import APIDbtRunner - RUNNER_CLASS = APIDbtRunner -else: - from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner - - RUNNER_CLASS = SubprocessDbtRunner +class RunnerMethod(Enum): + SUBPROCESS = "subprocess" + API = "api" + FUSION = "fusion" def create_dbt_runner( @@ -33,8 +29,11 @@ def create_dbt_runner( allow_macros_without_package_prefix: bool = False, run_deps_if_needed: bool = True, force_dbt_deps: bool = False, + runner_method: Optional[RunnerMethod] = None, ) -> CommandLineDbtRunner: - return RUNNER_CLASS( + runner_method = runner_method or get_dbt_runner_method() + runner_class = get_dbt_runner_class(runner_method) + return runner_class( project_dir=project_dir, profiles_dir=profiles_dir, target=target, @@ -46,3 +45,27 @@ def create_dbt_runner( run_deps_if_needed=run_deps_if_needed, force_dbt_deps=force_dbt_deps, ) + + +def get_dbt_runner_method() -> RunnerMethod: + runner_method = os.getenv("DBT_RUNNER_METHOD") + if runner_method: + return RunnerMethod(runner_method) + + if DBT_VERSION >= version.Version("1.5.0"): + return RunnerMethod.API + return RunnerMethod.SUBPROCESS + + +def get_dbt_runner_class(runner_method: RunnerMethod) -> Type[CommandLineDbtRunner]: + if runner_method == RunnerMethod.API: + # Import it internally since it will fail if the dbt version is below 1.5.0 + from elementary.clients.dbt.api_dbt_runner import APIDbtRunner + + return APIDbtRunner + elif runner_method == RunnerMethod.SUBPROCESS: + return SubprocessDbtRunner + elif runner_method == RunnerMethod.FUSION: + return DbtFusionRunner + else: + raise ValueError(f"Invalid runner method: {runner_method}") diff --git a/elementary/clients/dbt/subprocess_dbt_runner.py b/elementary/clients/dbt/subprocess_dbt_runner.py index 2e98db336..18d74cf23 100644 --- a/elementary/clients/dbt/subprocess_dbt_runner.py +++ b/elementary/clients/dbt/subprocess_dbt_runner.py @@ -26,7 +26,7 @@ def _inner_run_command( ) -> DbtCommandResult: try: result = subprocess.run( - ["dbt"] + dbt_command_args, + [self._get_dbt_command_name()] + dbt_command_args, check=self.raise_on_failure, capture_output=capture_output or quiet, env=self._get_command_env(), @@ -34,8 +34,9 @@ def _inner_run_command( ) success = result.returncode == 0 output = result.stdout.decode() if result.stdout else None + stderr = result.stderr.decode() if result.stderr else None - return DbtCommandResult(success=success, output=output) + return DbtCommandResult(success=success, output=output, stderr=stderr) except subprocess.CalledProcessError as err: logs = ( list(parse_dbt_output(err.output.decode(), log_format)) @@ -49,6 +50,9 @@ def _inner_run_command( base_command_args=dbt_command_args, logs=logs, err=err ) + def _get_dbt_command_name(self) -> str: + return "dbt" + def _parse_ls_command_result( self, select: Optional[str], result: DbtCommandResult ) -> List[str]: diff --git a/tests/tests_with_db/dbt_project/macros/create_all_types_table.sql b/tests/tests_with_db/dbt_project/macros/create_all_types_table.sql index 89d541887..6c587437f 100644 --- a/tests/tests_with_db/dbt_project/macros/create_all_types_table.sql +++ b/tests/tests_with_db/dbt_project/macros/create_all_types_table.sql @@ -31,9 +31,7 @@ CURRENT_TIME() as time_col, CURRENT_TIMESTAMP() as timestamp_col, {% endset %} - {% set create_table_query = dbt.create_table_as(false, relation, sql_query) %} - {% do elementary.edr_log(create_table_query) %} - {% do elementary.run_query(create_table_query) %} + {% do elementary.edr_create_table_as(false, relation, sql_query) %} {% endmacro %} {% macro snowflake__create_all_types_table() %} @@ -81,9 +79,7 @@ [1,2,3] as array_col, TO_GEOGRAPHY('POINT(-122.35 37.55)') as geography_col {% endset %} - {% set create_table_query = dbt.create_table_as(false, relation, sql_query) %} - {% do elementary.edr_log(create_table_query) %} - {% do elementary.run_query(create_table_query) %} + {% do elementary.edr_create_table_as(false, relation, sql_query) %} {% endmacro %} {% macro redshift__create_all_types_table() %} @@ -123,10 +119,7 @@ ST_GeogFromText('SRID=4324;POLYGON((0 0,0 1,1 1,10 10,1 0,0 0))') as geography_col, JSON_PARSE('{"data_type": "super"}') as super_col {% endset %} - {% set create_table_query = dbt.create_table_as(false, relation, sql_query) %} - {% do elementary.edr_log(create_table_query) %} - {% do elementary.run_query(create_table_query) %} - + {% do elementary.edr_create_table_as(false, relation, sql_query) %} {% endmacro %} {% macro postgres__create_all_types_table() %} @@ -184,9 +177,7 @@ 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'::uuid as uuid_col, xmlcomment('text') as xml_col {% endset %} - {% set create_table_query = dbt.create_table_as(false, relation, sql_query) %} - {% do elementary.edr_log(create_table_query) %} - {% do elementary.run_query(create_table_query) %} + {% do elementary.edr_create_table_as(false, relation, sql_query) %} {% endmacro %} {% macro default__create_all_types_table() %} diff --git a/tests/tests_with_db/dbt_project/macros/materializations.sql b/tests/tests_with_db/dbt_project/macros/materializations.sql index 51ae1c98e..cb1c5a450 100644 --- a/tests/tests_with_db/dbt_project/macros/materializations.sql +++ b/tests/tests_with_db/dbt_project/macros/materializations.sql @@ -1,19 +1,19 @@ {% materialization test, default %} {% if var('enable_elementary_test_materialization', false) %} - {% do return(elementary.materialization_test_default.call_macro()) %} + {% do return(elementary.materialization_test_default()) %} {% else %} - {% do return(dbt.materialization_test_default.call_macro()) %} + {% do return(dbt.materialization_test_default()) %} {% endif %} {% endmaterialization %} {% materialization test, adapter="snowflake" %} {% if var('enable_elementary_test_materialization', false) %} - {% do return(elementary.materialization_test_snowflake.call_macro()) %} + {% do return(elementary.materialization_test_snowflake()) %} {% else %} {% if dbt.materialization_test_snowflake %} - {% do return(dbt.materialization_test_snowflake.call_macro()) %} + {% do return(dbt.materialization_test_snowflake()) %} {% else %} - {% do return(dbt.materialization_test_default.call_macro()) %} + {% do return(dbt.materialization_test_default()) %} {% endif %} {% endif %} {% endmaterialization %}