Skip to content

Commit d4b6ea7

Browse files
datnguye and il-dat authored
feat: New algo - model contract constraints (#140)
* feat: New algo - model contract constraints * fix: complete testing with sample artifacts * feat: Update relation names and refactoring in model contract tests * feat(core): Support PK and multiple columns in relationships * fix(drawdb): lint errors * fix: update model relationships and constraints in integration tests --------- Co-authored-by: Dat Nguyen <dat@infinitelambda.com>
1 parent 0505424 commit d4b6ea7

37 files changed

+34570
-71
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ A huge thanks to our amazing contributors! 🙏
194194

195195
**Need help?** We're here for you! Check 📖 [Documentation](https://dbterd.datnguyen.de/), 🐛 [Report Issues](https://github.com/datnguye/dbterd/issues) and 💬 [Discussions](https://github.com/datnguye/dbterd/discussions)
196196

197+
198+
[![Star History Chart](https://api.star-history.com/image?repos=datnguye/dbterd&type=date&legend=top-left)](https://www.star-history.com/?repos=datnguye%2Fdbterd&type=date&legend=top-left)
199+
197200
---
198201

199202
<div align="center">
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
"""Model contract algorithm adapter for dbterd.
2+
3+
This module extracts tables and relationships from dbt artifacts
4+
using dbt model contract constraints (foreign_key) to determine connections.
5+
Requires manifest v12+ (dbt 1.9+) for the `to` and `to_columns` fields.
6+
"""
7+
8+
from typing import Optional, Union
9+
10+
from dbterd.constants import TEST_META_RELATIONSHIP_TYPE
11+
from dbterd.core.adapters.algo import BaseAlgoAdapter
12+
from dbterd.core.models import Ref, Table
13+
from dbterd.core.registry.decorators import register_algo
14+
from dbterd.helpers.log import logger
15+
from dbterd.types import Catalog, Manifest
16+
17+
18+
def _resolve_to_node_id(to_str: str, manifest_nodes: dict) -> Optional[str]:
19+
"""Resolve constraint.to to a manifest node unique ID.
20+
21+
The constraint.to value is a fully qualified relation name in the format:
22+
<database>.<schema>.<table_name> (e.g. "shaman.dummy.locations")
23+
24+
This is matched against each node's relation_name field.
25+
Only model resource type is currently supported.
26+
27+
Args:
28+
to_str: The constraint.to string (<database>.<schema>.<table_name>)
29+
manifest_nodes: Dict of manifest node IDs to node objects
30+
31+
Returns:
32+
Matching node unique ID, or None if not found.
33+
34+
"""
35+
if not to_str:
36+
return None
37+
38+
# model resource type takes priority over other resource types
39+
sorted_nodes = sorted(manifest_nodes.items(), key=lambda x: (0 if x[0].startswith("model.") else 1))
40+
for node_id, node in sorted_nodes:
41+
if getattr(node, "relation_name", None) == to_str:
42+
return node_id
43+
44+
return None
45+
46+
47+
def _get_relationship_type(meta_value: str) -> str:
48+
"""Get short form of the relationship type from meta.
49+
50+
Args:
51+
meta_value: Meta relationship_type value
52+
53+
Returns:
54+
Short relationship type code.
55+
56+
"""
57+
mapping = {
58+
"zero-to-many": "0n",
59+
"zero-to-one": "01",
60+
"one-to-one": "11",
61+
"many-to-many": "nn",
62+
"one-to-many": "1n",
63+
}
64+
return mapping.get(meta_value.lower(), "n1")
65+
66+
67+
def _extract_pk_column_names(node) -> list[str]:
68+
"""Extract primary key column names from a manifest node's constraints.
69+
70+
Checks both model-level constraints (constraint.columns where type=primary_key)
71+
and column-level constraints (column.constraints where type=primary_key).
72+
73+
Args:
74+
node: The manifest node object
75+
76+
Returns:
77+
List of column names that are part of the primary key.
78+
79+
"""
80+
pk_columns = []
81+
82+
if hasattr(node, "constraints") and node.constraints:
83+
for constraint in node.constraints:
84+
if constraint.type.value == "primary_key" and getattr(constraint, "columns", None):
85+
pk_columns.extend(constraint.columns)
86+
87+
if hasattr(node, "columns") and node.columns:
88+
for col_name, col in node.columns.items():
89+
if not hasattr(col, "constraints") or not col.constraints:
90+
continue
91+
for constraint in col.constraints:
92+
if constraint.type.value == "primary_key":
93+
pk_columns.append(col_name)
94+
95+
return pk_columns
96+
97+
98+
@register_algo("model_contract", description="Detect relationships via dbt model contract constraints")
class ModelContractAlgo(BaseAlgoAdapter):
    """Algorithm adapter using dbt model contract constraints.

    Extracts relationships from dbt's model contract foreign_key constraints
    (available in manifest v12+ / dbt 1.9+) to determine table connections.
    Only nodes whose unique ID starts with ``model.`` are scanned for foreign
    keys; a foreign key is ignored when its ``to`` is empty or cannot be
    resolved to a manifest node.
    """

    def parse_artifacts(self, manifest: Manifest, catalog: Catalog, **kwargs) -> tuple[list[Table], list[Ref]]:
        """Parse from file-based manifest/catalog artifacts.

        Args:
            manifest: Manifest data
            catalog: Catalog data
            **kwargs: Additional options, forwarded to the table/relationship helpers

        Returns:
            Tuple of (tables sorted by node_name, relationships sorted by name).

        """
        # get_tables / filter_tables_based_on_selection / make_up_relationships /
        # enrich_tables_from_relationships are not defined in this class;
        # presumably inherited from BaseAlgoAdapter.
        tables = self.get_tables(manifest=manifest, catalog=catalog, **kwargs)
        tables = self.filter_tables_based_on_selection(tables=tables, **kwargs)
        tables = self._enrich_tables_with_pk_info(tables=tables, manifest=manifest)

        relationships = self.get_relationships(manifest=manifest, **kwargs)
        relationships = self.make_up_relationships(relationships=relationships, tables=tables)

        tables = self.enrich_tables_from_relationships(tables=tables, relationships=relationships)

        logger.info(f"Collected {len(tables)} table(s) and {len(relationships)} relationship(s)")
        return (
            sorted(tables, key=lambda tbl: tbl.node_name),
            sorted(relationships, key=lambda rel: rel.name),
        )

    def _enrich_tables_with_pk_info(self, tables: list[Table], manifest: Manifest) -> list[Table]:
        """Mark columns as primary key based on manifest constraints.

        Column names are compared case-insensitively. Tables are modified in
        place and the same list is returned.

        Args:
            tables: List of parsed tables
            manifest: Manifest data

        Returns:
            Tables with is_primary_key set on relevant columns.

        """
        if not hasattr(manifest, "nodes"):
            return tables

        # node unique ID -> lower-cased primary key column names
        pk_map: dict[str, set[str]] = {}
        for node_name, node in manifest.nodes.items():
            pk_cols = _extract_pk_column_names(node)
            if pk_cols:
                pk_map[node_name] = {c.lower() for c in pk_cols}

        for table in tables:
            pks = pk_map.get(table.node_name)
            if not pks:
                continue
            for col in table.columns:
                if col.name.lower() in pks:
                    col.is_primary_key = True

        return tables

    def parse_metadata(self, data: dict, **kwargs) -> tuple[list[Table], list[Ref]]:
        """Parse from dbt Cloud metadata API response.

        Not supported for model_contract algorithm: a warning is logged and
        empty results are returned.
        """
        logger.warning("model_contract algorithm does not support dbt Cloud metadata API. Returning empty results.")
        return ([], [])

    def find_related_nodes_by_id(
        self,
        manifest: Union[Manifest, dict],
        node_unique_id: str,
        type: Optional[str] = None,
        **kwargs,
    ) -> list[str]:
        """Find FK models related to the input model ID via constraints.

        A model is related when it is the target of one of the input model's
        foreign keys, or when one of its own foreign keys targets the input
        model.

        Args:
            manifest: Manifest data
            node_unique_id: Manifest node unique ID
            type: Manifest type (local file or metadata)
            **kwargs: Additional options

        Returns:
            List of unique manifest node unique IDs, starting with
            ``node_unique_id`` itself, in discovery order.

        """
        found_nodes = [node_unique_id]
        # Metadata input and objects without ``nodes`` (e.g. a plain dict)
        # yield only the input ID.
        if type == "metadata" or not hasattr(manifest, "nodes"):
            return found_nodes

        for node_name, node in manifest.nodes.items():
            if not node_name.startswith("model."):
                continue

            for target_id in self._collect_fk_targets(node, manifest.nodes):
                if node_name == node_unique_id:
                    found_nodes.append(target_id)
                elif target_id == node_unique_id:
                    found_nodes.append(node_name)

        # Order-preserving de-duplication keeps the result deterministic
        # (a plain set would return the IDs in arbitrary order).
        return list(dict.fromkeys(found_nodes))

    def get_relationships(self, manifest: Manifest, **kwargs) -> list[Ref]:
        """Extract relationships from model contract constraints.

        Scans model.* nodes for column-level and model-level foreign_key
        constraints with populated `to` fields (manifest v12+).

        Args:
            manifest: Manifest data
            **kwargs: Additional options

        Returns:
            List of parsed relationships

        """
        if not hasattr(manifest, "nodes"):
            return []

        refs = []

        for node_name, node in manifest.nodes.items():
            if not node_name.startswith("model."):
                continue

            refs.extend(self._extract_column_level_refs(node_name, node, manifest.nodes))
            refs.extend(self._extract_model_level_refs(node_name, node, manifest.nodes))

        # get_unique_refs is not defined in this class; presumably inherited
        # from BaseAlgoAdapter.
        return self.get_unique_refs(refs=refs)

    def _collect_fk_targets(self, node, manifest_nodes: dict) -> list[str]:
        """Collect all FK target node IDs from a node's constraints.

        Column-level constraints are visited first, then model-level ones.
        The same target may appear more than once.

        Args:
            node: Manifest node object
            manifest_nodes: Dict of all manifest nodes

        Returns:
            List of target node IDs

        """
        targets = []

        if hasattr(node, "columns") and node.columns:
            for col in node.columns.values():
                if not hasattr(col, "constraints") or not col.constraints:
                    continue
                for constraint in col.constraints:
                    if constraint.type.value == "foreign_key" and getattr(constraint, "to", None):
                        target_id = _resolve_to_node_id(constraint.to, manifest_nodes)
                        if target_id:
                            targets.append(target_id)

        if hasattr(node, "constraints") and node.constraints:
            for constraint in node.constraints:
                if constraint.type.value == "foreign_key" and getattr(constraint, "to", None):
                    target_id = _resolve_to_node_id(constraint.to, manifest_nodes)
                    if target_id:
                        targets.append(target_id)

        return targets

    def _extract_column_level_refs(self, node_name: str, node, manifest_nodes: dict) -> list[Ref]:
        """Extract Ref objects from column-level FK constraints.

        One Ref is produced per target column. When the constraint has no
        ``to_columns`` the constrained column's own name is used as the target
        column. The relationship type is read from the column's meta and the
        label from the node's ``relationship_labels`` meta keyed by the
        constraint name.

        Args:
            node_name: The node unique ID (e.g. model.pkg.orders)
            node: The manifest node object
            manifest_nodes: Dict of all manifest nodes

        Returns:
            List of Ref objects

        """
        refs = []

        if not hasattr(node, "columns") or not node.columns:
            return refs

        # Loop-invariant; ``or {}`` also covers a key that is present but null.
        node_meta = getattr(node, "meta", None) or {}
        relationship_labels = node_meta.get("relationship_labels") or {}

        for col_name, col in node.columns.items():
            if not hasattr(col, "constraints") or not col.constraints:
                continue

            for constraint in col.constraints:
                if constraint.type.value != "foreign_key":
                    continue
                if not getattr(constraint, "to", None):
                    continue

                to_node_id = _resolve_to_node_id(constraint.to, manifest_nodes)
                if not to_node_id:
                    continue

                to_columns = getattr(constraint, "to_columns", None) or [col_name]
                col_meta = getattr(col, "meta", None) or {}
                # A null meta value is treated the same as a missing one.
                relationship_type = _get_relationship_type(col_meta.get(TEST_META_RELATIONSHIP_TYPE) or "")

                # Empty names are normalised to None so the Ref falls back to the node ID.
                constraint_name = getattr(constraint, "name", None) or None
                relationship_label = relationship_labels.get(constraint_name)
                for to_column in to_columns:
                    refs.append(
                        Ref(
                            name=constraint_name or node_name,
                            table_map=[to_node_id, node_name],
                            column_map=([to_column], [col_name]),
                            type=relationship_type,
                            relationship_label=relationship_label,
                        )
                    )

        return refs

    def _extract_model_level_refs(self, node_name: str, node, manifest_nodes: dict) -> list[Ref]:
        """Extract Ref objects from model-level FK constraints.

        One Ref is produced per constraint, carrying all of its columns. When
        the constraint has no ``to_columns`` its own ``columns`` are used as
        the target columns. The relationship type comes from the node's
        ``relationship_types`` meta and the label from ``relationship_labels``
        (both keyed by the constraint name), with the node-wide
        ``relationship_label`` meta as a fallback for the label.

        Args:
            node_name: The node unique ID (e.g. model.pkg.orders)
            node: The manifest node object
            manifest_nodes: Dict of all manifest nodes

        Returns:
            List of Ref objects

        """
        refs = []

        if not hasattr(node, "constraints") or not node.constraints:
            return refs

        # Loop-invariant; ``or {}`` also covers keys that are present but null.
        node_meta = getattr(node, "meta", None) or {}
        relationship_types = node_meta.get("relationship_types") or {}
        relationship_labels = node_meta.get("relationship_labels") or {}

        for constraint in node.constraints:
            if constraint.type.value != "foreign_key":
                continue
            if not getattr(constraint, "to", None):
                continue
            if not getattr(constraint, "columns", None):
                continue

            to_node_id = _resolve_to_node_id(constraint.to, manifest_nodes)
            if not to_node_id:
                continue

            to_columns = list(getattr(constraint, "to_columns", None) or constraint.columns)
            # Empty names are normalised to None so the Ref falls back to the node ID.
            constraint_name = getattr(constraint, "name", None) or None
            relationship_type = _get_relationship_type(relationship_types.get(constraint_name) or "")
            relationship_label = relationship_labels.get(constraint_name) or node_meta.get("relationship_label")

            refs.append(
                Ref(
                    name=constraint_name or node_name,
                    table_map=[to_node_id, node_name],
                    column_map=(to_columns, list(constraint.columns)),
                    type=relationship_type,
                    relationship_label=relationship_label,
                )
            )

        return refs

dbterd/adapters/algos/semantic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ def get_relationships(self, manifest: Manifest) -> list[Ref]:
259259
name=primary_entity.semantic_model,
260260
table_map=(primary_entity.model, foreign_entity.model),
261261
column_map=(
262-
primary_entity.column_name,
263-
foreign_entity.column_name,
262+
[primary_entity.column_name],
263+
[foreign_entity.column_name],
264264
),
265265
type=primary_entity.relationship_type,
266266
)
@@ -285,8 +285,8 @@ def get_relationships_from_metadata(self, data: list) -> list[Ref]:
285285
name=primary_entity.semantic_model,
286286
table_map=(primary_entity.model, foreign_entity.model),
287287
column_map=(
288-
primary_entity.column_name,
289-
foreign_entity.column_name,
288+
[primary_entity.column_name],
289+
[foreign_entity.column_name],
290290
),
291291
type=primary_entity.relationship_type,
292292
)

0 commit comments

Comments
 (0)