Skip to content

Commit 84b3901

Browse files
authored
Fix: Automatically repair dbt test circular references by moving upstream test to downstream model (#5253)
1 parent ecdbbde commit 84b3901

File tree

2 files changed

+45
-31
lines changed

2 files changed

+45
-31
lines changed

sqlmesh/dbt/basemodel.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from abc import abstractmethod
55
from enum import Enum
66
from pathlib import Path
7+
import logging
78

89
from pydantic import Field
910
from sqlglot.helper import ensure_list
@@ -38,6 +39,9 @@
3839
BMC = t.TypeVar("BMC", bound="BaseModelConfig")
3940

4041

42+
logger = logging.getLogger(__name__)
43+
44+
4145
class Materialization(str, Enum):
4246
"""DBT model materializations"""
4347

@@ -262,37 +266,32 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:
262266
and all(source in context.sources for source in test.dependencies.sources)
263267
]
264268

265-
def check_for_circular_test_refs(self, context: DbtContext) -> None:
269+
def fix_circular_test_refs(self, context: DbtContext) -> None:
266270
"""
267-
Checks for direct circular references between two models and raises an exception if found.
268-
This addresses the most common circular reference seen when importing a dbt project -
269-
relationship tests in both directions. In the future, we may want to increase coverage by
270-
checking for indirect circular references.
271+
Checks for direct circular references between two models and moves the test to the downstream
272+
model if found. This addresses the most common circular reference - relationship tests in both
273+
directions. In the future, we may want to increase coverage by checking for indirect circular references.
271274
272275
Args:
273276
context: The dbt context this model resides within.
274277
275278
Returns:
276279
None
277280
"""
278-
for test in self.tests:
281+
for test in self.tests.copy():
279282
for ref in test.dependencies.refs:
280-
model = context.refs[ref]
281283
if ref == self.name or ref in self.dependencies.refs:
282284
continue
283-
elif self.name in model.dependencies.refs:
284-
raise ConfigError(
285-
f"Test '{test.name}' for model '{self.name}' depends on downstream model '{model.name}'."
286-
" Move the test to the downstream model to avoid circular references."
287-
)
288-
elif self.name in model.tests_ref_source_dependencies.refs:
289-
circular_test = next(
290-
test.name for test in model.tests if ref in test.dependencies.refs
291-
)
292-
raise ConfigError(
293-
f"Circular reference detected between tests for models '{self.name}' and '{model.name}':"
294-
f" '{test.name}' ({self.name}), '{circular_test}' ({model.name})."
285+
model = context.refs[ref]
286+
if (
287+
self.name in model.dependencies.refs
288+
or self.name in model.tests_ref_source_dependencies.refs
289+
):
290+
logger.info(
291+
f"Moving test '{test.name}' from model '{self.name}' to '{model.name}' to avoid circular reference."
295292
)
293+
model.tests.append(test)
294+
self.tests.remove(test)
296295

297296
@property
298297
def sqlmesh_config_fields(self) -> t.Set[str]:
@@ -313,7 +312,7 @@ def sqlmesh_model_kwargs(
313312
) -> t.Dict[str, t.Any]:
314313
"""Get common sqlmesh model parameters"""
315314
self.remove_tests_with_invalid_refs(context)
316-
self.check_for_circular_test_refs(context)
315+
self.fix_circular_test_refs(context)
317316

318317
dependencies = self.dependencies.copy()
319318
if dependencies.has_dynamic_var_names:

tests/dbt/test_model.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from sqlmesh.dbt.model import ModelConfig
99
from sqlmesh.dbt.target import PostgresConfig
1010
from sqlmesh.dbt.test import TestConfig
11-
from sqlmesh.utils.errors import ConfigError
1211
from sqlmesh.utils.yaml import YAML
1312

1413
pytestmark = pytest.mark.dbt
@@ -30,25 +29,41 @@ def test_model_test_circular_references() -> None:
3029
sql="",
3130
dependencies=Dependencies(refs={"upstream", "downstream"}),
3231
)
32+
33+
# No circular reference
3334
downstream_model.tests = [downstream_test]
34-
downstream_model.check_for_circular_test_refs(context)
35+
downstream_model.fix_circular_test_refs(context)
36+
assert upstream_model.tests == []
37+
assert downstream_model.tests == [downstream_test]
3538

39+
# Upstream model reference in downstream model
3640
downstream_model.tests = []
3741
upstream_model.tests = [upstream_test]
38-
with pytest.raises(ConfigError, match="downstream model"):
39-
upstream_model.check_for_circular_test_refs(context)
42+
upstream_model.fix_circular_test_refs(context)
43+
assert upstream_model.tests == []
44+
assert downstream_model.tests == [upstream_test]
4045

46+
upstream_model.tests = [upstream_test]
4147
downstream_model.tests = [downstream_test]
42-
with pytest.raises(ConfigError, match="downstream model"):
43-
upstream_model.check_for_circular_test_refs(context)
44-
downstream_model.check_for_circular_test_refs(context)
48+
upstream_model.fix_circular_test_refs(context)
49+
assert upstream_model.tests == []
50+
assert downstream_model.tests == [downstream_test, upstream_test]
51+
52+
downstream_model.fix_circular_test_refs(context)
53+
assert upstream_model.tests == []
54+
assert downstream_model.tests == [downstream_test, upstream_test]
4555

4656
# Test only references
57+
upstream_model.tests = [upstream_test]
58+
downstream_model.tests = [downstream_test]
4759
downstream_model.dependencies = Dependencies()
48-
with pytest.raises(ConfigError, match="between tests"):
49-
upstream_model.check_for_circular_test_refs(context)
50-
with pytest.raises(ConfigError, match="between tests"):
51-
downstream_model.check_for_circular_test_refs(context)
60+
upstream_model.fix_circular_test_refs(context)
61+
assert upstream_model.tests == []
62+
assert downstream_model.tests == [downstream_test, upstream_test]
63+
64+
downstream_model.fix_circular_test_refs(context)
65+
assert upstream_model.tests == []
66+
assert downstream_model.tests == [downstream_test, upstream_test]
5267

5368

5469
@pytest.mark.slow

0 commit comments

Comments
 (0)