Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion elementary/clients/dbt/api_dbt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def collect_dbt_command_logs(event):
logs=[DbtLog.from_log_line(log) for log in dbt_logs],
)

return APIDbtCommandResult(success=res.success, output=output, result_obj=res)
return APIDbtCommandResult(
success=res.success, output=output, stderr=None, result_obj=res
)

def _parse_ls_command_result(
self, select: Optional[str], result: DbtCommandResult
Expand Down
28 changes: 18 additions & 10 deletions elementary/clients/dbt/command_line_dbt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
class DbtCommandResult:
success: bool
output: Optional[str]
stderr: Optional[str]


class CommandLineDbtRunner(BaseDbtRunner):
Expand Down Expand Up @@ -190,16 +191,23 @@ def run_operation(
log_pattern = (
RAW_EDR_LOGS_PATTERN if return_raw_edr_logs else MACRO_RESULT_PATTERN
)
if capture_output and result.output is not None:
for log in parse_dbt_output(result.output):
if log_errors and log.level == "error":
logger.error(log.msg)
continue

if log.msg:
match = log_pattern.match(log.msg)
if match:
run_operation_results.append(match.group(1))
if capture_output:
if result.output is not None:
for log in parse_dbt_output(result.output):
if log_errors and log.level == "error":
logger.error(log.msg)
continue

if log.msg:
match = log_pattern.match(log.msg)
if match:
run_operation_results.append(match.group(1))

if result.stderr is not None and log_errors:
for log in parse_dbt_output(result.stderr):
if log.level == "error":
logger.error(log.msg)
continue

return run_operation_results

Expand Down
14 changes: 14 additions & 0 deletions elementary/clients/dbt/dbt_fusion_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os

from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner

# Location of the dbt Fusion executable. The DBT_FUSION_PATH environment
# variable overrides the default; the value is stored unexpanded ("~" intact).
DBT_FUSION_PATH = os.environ.get("DBT_FUSION_PATH", "~/.local/bin/dbt")


class DbtFusionRunner(SubprocessDbtRunner):
    """SubprocessDbtRunner variant that targets the dbt Fusion executable."""

    def _get_dbt_command_name(self) -> str:
        """Return DBT_FUSION_PATH with a leading "~" expanded to the home dir."""
        fusion_executable = os.path.expanduser(DBT_FUSION_PATH)
        return fusion_executable

    def _run_deps_if_needed(self):
        """Do nothing: deps are never run automatically for dbt Fusion."""
        # Currently we don't support auto-updating deps for dbt fusion
        return None
47 changes: 35 additions & 12 deletions elementary/clients/dbt/factory.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
import os
from enum import Enum
from typing import Any, Dict, Optional, Type

from dbt.version import __version__ as dbt_version_string
from packaging import version

from elementary.clients.dbt.command_line_dbt_runner import CommandLineDbtRunner
from elementary.clients.dbt.dbt_fusion_runner import DbtFusionRunner
from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner

DBT_VERSION = version.Version(dbt_version_string)

RUNNER_CLASS: Type[CommandLineDbtRunner]
if (
DBT_VERSION >= version.Version("1.5.0")
and os.getenv("DBT_RUNNER_METHOD") != "subprocess"
):
from elementary.clients.dbt.api_dbt_runner import APIDbtRunner

RUNNER_CLASS = APIDbtRunner
else:
from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner

RUNNER_CLASS = SubprocessDbtRunner
class RunnerMethod(Enum):
    """Closed set of ways a dbt runner can be selected.

    Each member's value is the string that identifies it, so a member can be
    looked up from its string form with ``RunnerMethod(value)``.
    """

    # Selects SubprocessDbtRunner in get_dbt_runner_class.
    SUBPROCESS = "subprocess"
    # Selects APIDbtRunner in get_dbt_runner_class.
    API = "api"
    # Selects DbtFusionRunner in get_dbt_runner_class.
    FUSION = "fusion"


def create_dbt_runner(
Expand All @@ -33,8 +29,11 @@ def create_dbt_runner(
allow_macros_without_package_prefix: bool = False,
run_deps_if_needed: bool = True,
force_dbt_deps: bool = False,
runner_method: Optional[RunnerMethod] = None,
) -> CommandLineDbtRunner:
return RUNNER_CLASS(
runner_method = runner_method or get_dbt_runner_method()
runner_class = get_dbt_runner_class(runner_method)
return runner_class(
project_dir=project_dir,
profiles_dir=profiles_dir,
target=target,
Expand All @@ -46,3 +45,27 @@ def create_dbt_runner(
run_deps_if_needed=run_deps_if_needed,
force_dbt_deps=force_dbt_deps,
)


def get_dbt_runner_method() -> RunnerMethod:
    """Pick the dbt runner method from the environment or the dbt version.

    The ``DBT_RUNNER_METHOD`` environment variable wins when it is set to a
    non-empty value. It is matched against the ``RunnerMethod`` values after
    trimming surrounding whitespace and lower-casing, so ``"API"`` and
    ``" fusion "`` are accepted. Without the variable, the API runner is used
    for dbt >= 1.5.0 and the subprocess runner otherwise.

    Returns:
        The selected ``RunnerMethod`` member.

    Raises:
        ValueError: If ``DBT_RUNNER_METHOD`` is set to an unknown value. The
            message names the variable and lists the accepted values.
    """
    configured_method = os.getenv("DBT_RUNNER_METHOD")
    if configured_method:
        normalized_method = configured_method.strip().lower()
        try:
            return RunnerMethod(normalized_method)
        except ValueError:
            # The bare enum lookup error does not mention the environment
            # variable, which makes a misconfiguration hard to trace.
            valid_values = ", ".join(method.value for method in RunnerMethod)
            raise ValueError(
                f"Invalid DBT_RUNNER_METHOD value: {configured_method!r}. "
                f"Valid values are: {valid_values}."
            ) from None

    if DBT_VERSION >= version.Version("1.5.0"):
        return RunnerMethod.API
    return RunnerMethod.SUBPROCESS


def get_dbt_runner_class(runner_method: RunnerMethod) -> Type[CommandLineDbtRunner]:
    """Map a runner method to the runner class that implements it.

    Args:
        runner_method: The method to resolve.

    Returns:
        ``SubprocessDbtRunner``, ``DbtFusionRunner`` or ``APIDbtRunner``.

    Raises:
        ValueError: If ``runner_method`` is none of the known methods.
    """
    if runner_method == RunnerMethod.SUBPROCESS:
        return SubprocessDbtRunner
    if runner_method == RunnerMethod.FUSION:
        return DbtFusionRunner
    if runner_method != RunnerMethod.API:
        raise ValueError(f"Invalid runner method: {runner_method}")

    # Import it internally since it will fail if the dbt version is below 1.5.0
    from elementary.clients.dbt.api_dbt_runner import APIDbtRunner

    return APIDbtRunner
8 changes: 6 additions & 2 deletions elementary/clients/dbt/subprocess_dbt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@ def _inner_run_command(
) -> DbtCommandResult:
try:
result = subprocess.run(
["dbt"] + dbt_command_args,
[self._get_dbt_command_name()] + dbt_command_args,
check=self.raise_on_failure,
capture_output=capture_output or quiet,
env=self._get_command_env(),
cwd=self.project_dir,
)
success = result.returncode == 0
Comment on lines 28 to 35
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add an execution timeout to avoid hung subprocesses.

Long/hung dbt processes will block indefinitely; add a configurable timeout.

-            result = subprocess.run(
+            result = subprocess.run(
                 [self._get_dbt_command_name(), *dbt_command_args],
                 check=self.raise_on_failure,
                 capture_output=capture_output or quiet,
                 text=True,
                 encoding="utf-8",
                 errors="replace",
                 env=self._get_command_env(),
                 cwd=self.project_dir,
+                timeout=float(os.getenv("EDR_DBT_SUBPROCESS_TIMEOUT_SEC", "0")) or None,
             )

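Optionally document EDR_DBT_SUBPROCESS_TIMEOUT_SEC in README. Note: with text=True, result.stdout and result.stderr are already str, so the later .decode() calls on them must be removed, and os must be imported in this module if it is not already.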

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
result = subprocess.run(
["dbt"] + dbt_command_args,
[self._get_dbt_command_name()] + dbt_command_args,
check=self.raise_on_failure,
capture_output=capture_output or quiet,
env=self._get_command_env(),
cwd=self.project_dir,
)
success = result.returncode == 0
result = subprocess.run(
[self._get_dbt_command_name(), *dbt_command_args],
check=self.raise_on_failure,
capture_output=capture_output or quiet,
text=True,
encoding="utf-8",
errors="replace",
env=self._get_command_env(),
cwd=self.project_dir,
timeout=float(os.getenv("EDR_DBT_SUBPROCESS_TIMEOUT_SEC", "0")) or None,
)
success = result.returncode == 0
🧰 Tools
🪛 Ruff (0.13.1)

28-28: subprocess call: check for execution of untrusted input

(S603)


29-29: Consider [self._get_dbt_command_name(), *dbt_command_args] instead of concatenation

Replace with [self._get_dbt_command_name(), *dbt_command_args]

(RUF005)

🤖 Prompt for AI Agents
In elementary/clients/dbt/subprocess_dbt_runner.py around lines 28-35, the
subprocess.run call can hang indefinitely; add a configurable timeout by reading
EDR_DBT_SUBPROCESS_TIMEOUT_SEC from the environment (fallback to a sensible
default, e.g., 300s), convert to int, and pass it as the timeout argument to
subprocess.run; handle subprocess.TimeoutExpired by logging/raising a clear
error or marking the run as failed and ensuring the child process is cleaned up;
update any callers/tests accordingly and optionally document the new
EDR_DBT_SUBPROCESS_TIMEOUT_SEC setting in the README.

output = result.stdout.decode() if result.stdout else None
stderr = result.stderr.decode() if result.stderr else None

return DbtCommandResult(success=success, output=output)
return DbtCommandResult(success=success, output=output, stderr=stderr)
except subprocess.CalledProcessError as err:
logs = (
list(parse_dbt_output(err.output.decode(), log_format))
Expand All @@ -49,6 +50,9 @@ def _inner_run_command(
base_command_args=dbt_command_args, logs=logs, err=err
)

def _get_dbt_command_name(self) -> str:
return "dbt"

def _parse_ls_command_result(
self, select: Optional[str], result: DbtCommandResult
) -> List[str]:
Expand Down
17 changes: 4 additions & 13 deletions tests/tests_with_db/dbt_project/macros/create_all_types_table.sql
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@
CURRENT_TIME() as time_col,
CURRENT_TIMESTAMP() as timestamp_col,
{% endset %}
{% set create_table_query = dbt.create_table_as(false, relation, sql_query) %}
{% do elementary.edr_log(create_table_query) %}
{% do elementary.run_query(create_table_query) %}
{% do elementary.edr_create_table_as(false, relation, sql_query) %}
{% endmacro %}

{% macro snowflake__create_all_types_table() %}
Expand Down Expand Up @@ -81,9 +79,7 @@
[1,2,3] as array_col,
TO_GEOGRAPHY('POINT(-122.35 37.55)') as geography_col
{% endset %}
{% set create_table_query = dbt.create_table_as(false, relation, sql_query) %}
{% do elementary.edr_log(create_table_query) %}
{% do elementary.run_query(create_table_query) %}
{% do elementary.edr_create_table_as(false, relation, sql_query) %}
{% endmacro %}

{% macro redshift__create_all_types_table() %}
Expand Down Expand Up @@ -123,10 +119,7 @@
ST_GeogFromText('SRID=4324;POLYGON((0 0,0 1,1 1,10 10,1 0,0 0))') as geography_col,
JSON_PARSE('{"data_type": "super"}') as super_col
{% endset %}
{% set create_table_query = dbt.create_table_as(false, relation, sql_query) %}
{% do elementary.edr_log(create_table_query) %}
{% do elementary.run_query(create_table_query) %}

{% do elementary.edr_create_table_as(false, relation, sql_query) %}
{% endmacro %}

{% macro postgres__create_all_types_table() %}
Expand Down Expand Up @@ -184,9 +177,7 @@
'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'::uuid as uuid_col,
xmlcomment('text') as xml_col
{% endset %}
{% set create_table_query = dbt.create_table_as(false, relation, sql_query) %}
{% do elementary.edr_log(create_table_query) %}
{% do elementary.run_query(create_table_query) %}
{% do elementary.edr_create_table_as(false, relation, sql_query) %}
{% endmacro %}

{% macro default__create_all_types_table() %}
Expand Down
10 changes: 5 additions & 5 deletions tests/tests_with_db/dbt_project/macros/materializations.sql
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
{% materialization test, default %}
{% if var('enable_elementary_test_materialization', false) %}
{% do return(elementary.materialization_test_default.call_macro()) %}
{% do return(elementary.materialization_test_default()) %}
{% else %}
{% do return(dbt.materialization_test_default.call_macro()) %}
{% do return(dbt.materialization_test_default()) %}
{% endif %}
{% endmaterialization %}

{% materialization test, adapter="snowflake" %}
{% if var('enable_elementary_test_materialization', false) %}
{% do return(elementary.materialization_test_snowflake.call_macro()) %}
{% do return(elementary.materialization_test_snowflake()) %}
{% else %}
{% if dbt.materialization_test_snowflake %}
{% do return(dbt.materialization_test_snowflake.call_macro()) %}
{% do return(dbt.materialization_test_snowflake()) %}
{% else %}
{% do return(dbt.materialization_test_default.call_macro()) %}
{% do return(dbt.materialization_test_default()) %}
{% endif %}
{% endif %}
{% endmaterialization %}