
Commit d77afdd

--wip-- [skip ci]
1 parent 709bd11 commit d77afdd

32 files changed: +17216 -10 lines

core/dbt/cli/requires.py

Lines changed: 8 additions & 1 deletion
@@ -31,6 +31,8 @@
 from dbt.exceptions import DbtProjectError, FailFastError
 from dbt.flags import get_flag_dict, get_flags, set_flags
 from dbt.mp_context import get_mp_context
+from dbt.openlineage.handler import OpenLineageHandler
+from dbt.openlineage.common.utils import is_runnable_dbt_command
 from dbt.parser.manifest import parse_manifest
 from dbt.plugins import set_up_plugin_manager
 from dbt.profiler import profiler
@@ -82,8 +84,13 @@ def wrapper(*args, **kwargs):
         # Reset invocation_id for each 'invocation' of a dbt command (can happen multiple times in a single process)
         reset_invocation_id()

-        # Logging
+        # OpenLineage
+        ol_handler = OpenLineageHandler(ctx)
         callbacks = ctx.obj.get("callbacks", [])
+        if is_runnable_dbt_command(flags):
+            callbacks.append(ol_handler.handle)
+
+        # Logging
         setup_event_logger(flags=flags, callbacks=callbacks)

         # Tracking
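The wiring above reuses dbt's existing callback mechanism: anything appended to callbacks before setup_event_logger(flags=flags, callbacks=callbacks) runs is invoked for every event dbt fires. A minimal sketch of a callback in that style (print_event_names is illustrative and not part of this commit; it assumes dbt_common's EventMsg type, which is what dbt passes to registered callbacks):

from dbt_common.events.base_types import EventMsg

def print_event_names(msg: EventMsg) -> None:
    # Each fired event arrives as an EventMsg; msg.info.name holds the
    # event type name (e.g. "MainReportVersion").
    print(msg.info.name)

# Registered the same way ol_handler.handle is registered above:
# callbacks.append(print_event_names)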

core/dbt/events/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -1,8 +1,11 @@
 from typing import Any, Dict, Set

 import dbt.adapters.events.types as adapter_dbt_event_types
+import dbt.adapters.events.adapter_types_pb2 as adapter_types_pb2
+import dbt.events.core_types_pb2 as core_dbt_event_types_pb2
 import dbt.events.types as core_dbt_event_types
 import dbt_common.events.types as dbt_event_types
+import dbt_common.events.types_pb2 as dbt_event_types_pb2

 ALL_EVENT_TYPES: Dict[str, Any] = {
     **dbt_event_types.__dict__,
@@ -13,3 +16,10 @@
 ALL_EVENT_NAMES: Set[str] = set(
     [name for name, cls in ALL_EVENT_TYPES.items() if isinstance(cls, type)]
 )
+
+
+ALL_PROTO_TYPES: Dict[str, Any] = {
+    **dbt_event_types_pb2.__dict__,
+    **core_dbt_event_types_pb2.__dict__,
+    **adapter_types_pb2.__dict__,
+}
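ALL_PROTO_TYPES mirrors ALL_EVENT_TYPES but over the generated *_pb2 modules, so a protobuf message class can be resolved from a name at runtime. A hedged sketch of that lookup (the "MainReportVersionMsg" name is just an example of the Msg-suffixed classes the generated modules contain):

from dbt.events import ALL_PROTO_TYPES

# Resolve the generated protobuf class for an event by name, if present.
msg_cls = ALL_PROTO_TYPES.get("MainReportVersionMsg")
if msg_cls is not None:
    empty_msg = msg_cls()  # an empty protobuf message of that event type
    print(type(empty_msg).__name__)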

core/dbt/events/types.py

Lines changed: 10 additions & 0 deletions
@@ -2416,3 +2416,13 @@ def code(self) -> str:

     def message(self) -> str:
         return f"Artifacts skipped for command : {self.msg}"
+
+
+class OpenLineageException(WarnLevel):
+    def code(self) -> str:
+        return "Z064"
+
+    def message(self) -> str:
+        return (
+            f"Encountered an error while creating OpenLineageEvent: {self.exc}\n"
+            f"{self.exc_info}"
+        )
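OpenLineageException is a WarnLevel event, so a failure inside the OpenLineage handler degrades to a warning instead of aborting the run. A hedged sketch of how such an event might be fired (fire_event is dbt's standard emit helper; the exc and exc_info fields referenced by message() are assumed to be defined for this event in core_types.proto):

import traceback

from dbt_common.events.functions import fire_event
from dbt.events.types import OpenLineageException

try:
    raise ValueError("demo failure while building an OpenLineage event")
except Exception as e:
    # Emit as a warning-level structured event rather than re-raising.
    fire_event(OpenLineageException(exc=str(e), exc_info=traceback.format_exc()))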

core/dbt/openlineage/__init__.py

Whitespace-only changes.

core/dbt/openlineage/common/__init__.py

Whitespace-only changes.
Lines changed: 208 additions & 0 deletions

@@ -0,0 +1,208 @@
+import enum
+
+from dbt.adapters.contracts.connection import Credentials
+
+from openlineage.client.client import OpenLineageClient
+from openlineage.client.uuid import generate_new_uuid
+from openlineage.client.event_v2 import Dataset, Job, Run, RunEvent, RunState, InputDataset, OutputDataset
+from openlineage.client.facet_v2 import (
+    BaseFacet,
+    DatasetFacet,
+    InputDatasetFacet,
+    JobFacet,
+    OutputDatasetFacet,
+    column_lineage_dataset,
+    data_quality_assertions_dataset,
+    datasource_dataset,
+    documentation_dataset,
+    job_type_job,
+    output_statistics_output_dataset,
+    parent_run,
+    schema_dataset,
+    sql_job,
+    error_message_run,
+)
+from typing import List
+
+from dbt.events.types import OpenLineageException
+from dbt.config.project import Project
+
+from dbt.artifacts.resources.types import NodeType
+from dbt.contracts.graph.manifest import Manifest
+from dbt.contracts.graph.nodes import (
+    SourceDefinition,
+    ManifestNode,
+    ModelNode,
+    GenericTestNode,
+    SingularTestNode,
+    SeedNode,
+    ManifestSQLNode,
+)
+
+
+class Adapter(enum.Enum):
+    # Supported adapters.
+    BIGQUERY = "bigquery"
+    SNOWFLAKE = "snowflake"
+    REDSHIFT = "redshift"
+    SPARK = "spark"
+    POSTGRES = "postgres"
+    DATABRICKS = "databricks"
+    SQLSERVER = "sqlserver"
+    DREMIO = "dremio"
+    ATHENA = "athena"
+    DUCKDB = "duckdb"
+    TRINO = "trino"
+
+    @staticmethod
+    def adapters() -> str:
+        # String representation of all supported adapter names.
+        return ",".join([f"`{x.value}`" for x in list(Adapter)])
+
+
+class SparkConnectionMethod(enum.Enum):
+    THRIFT = "thrift"
+    ODBC = "odbc"
+    HTTP = "http"
+
+    @staticmethod
+    def methods():
+        return [x.value for x in SparkConnectionMethod]
+
+
+def extract_schema_dataset_facet(node: ManifestNode) -> List[schema_dataset.SchemaDatasetFacetFields]:
+    if node.resource_type == NodeType.Seed:
+        return _extract_schema_dataset_from_seed(node)
+    else:
+        return _extract_schema_dataset_facet_from_manifest_sql_node(node)
+
+
+def _extract_schema_dataset_facet_from_manifest_sql_node(
+    manifest_sql_node: ManifestSQLNode,
+) -> List[schema_dataset.SchemaDatasetFacetFields]:
+    schema_fields = []
+    for column_info in manifest_sql_node.columns.values():
+        description = column_info.description
+        name = column_info.name
+        data_type = column_info.data_type or ""
+        schema_fields.append(
+            schema_dataset.SchemaDatasetFacetFields(name=name, description=description, type=data_type)
+        )
+    return schema_fields
+
+
+def _extract_schema_dataset_from_seed(seed: SeedNode) -> List[schema_dataset.SchemaDatasetFacetFields]:
+    schema_fields = []
+    for col_name, col_type in seed.config.column_types.items():
+        schema_fields.append(schema_dataset.SchemaDatasetFacetFields(name=col_name, type=col_type))
+    return schema_fields
+
+
+def get_test_column(test_node: GenericTestNode | SingularTestNode) -> List[str]:
+    if test_node.column_name:
+        return [test_node.column_name]
+    elif test_node.test_metadata.kwargs:
+        return [test_node.test_metadata.kwargs["column_name"]]
+    else:
+        return []
+
+
+def get_model_inputs(node_unique_id: str, manifest: Manifest) -> List[ModelNode | SourceDefinition]:
+    upstreams = []
+    input_node_ids = manifest.parent_map.get(node_unique_id, [])
+    for input_node_id in input_node_ids:
+        if input_node_id.startswith("source."):
+            upstreams.append(manifest.sources[input_node_id])
+        else:
+            upstreams.append(manifest.nodes[input_node_id])
+    return upstreams
+
+
+def node_to_dataset(node: ManifestNode | SourceDefinition, dataset_namespace: str) -> Dataset:
+    facets = {
+        "dataSource": datasource_dataset.DatasourceDatasetFacet(
+            name=dataset_namespace, uri=dataset_namespace
+        ),
+        "schema": schema_dataset.SchemaDatasetFacet(
+            fields=extract_schema_dataset_facet(node)
+        ),
+        "documentation": documentation_dataset.DocumentationDatasetFacet(
+            description=node.description
+        ),
+    }
+    node_fqn = ".".join(node.fqn)
+    return Dataset(namespace=dataset_namespace, name=node_fqn, facets=facets)
+
+
+def extract_namespace(adapter: Credentials) -> str:
+    """Extract the dataset namespace from the profile's adapter type."""
+    if adapter.type == Adapter.SNOWFLAKE.value:
+        return f"snowflake://{_fix_account_name(adapter.account)}"
+    elif adapter.type == Adapter.BIGQUERY.value:
+        return "bigquery"
+    elif adapter.type == Adapter.REDSHIFT.value:
+        return f"redshift://{adapter.host}:{adapter.port}"
+    elif adapter.type == Adapter.POSTGRES.value:
+        return f"postgres://{adapter.host}:{adapter.port}"
+    elif adapter.type == Adapter.TRINO.value:
+        return f"trino://{adapter.host}:{adapter.port}"
+    elif adapter.type == Adapter.DATABRICKS.value:
+        return f"databricks://{adapter.host}"
+    elif adapter.type == Adapter.SQLSERVER.value:
+        return f"mssql://{adapter.server}:{adapter.port}"
+    elif adapter.type == Adapter.DREMIO.value:
+        return f"dremio://{adapter.software_host}:{adapter.port}"
+    elif adapter.type == Adapter.ATHENA.value:
+        return f"awsathena://athena.{adapter.region_name}.amazonaws.com"
+    elif adapter.type == Adapter.DUCKDB.value:
+        return f"duckdb://{adapter.path}"
+    elif adapter.type == Adapter.SPARK.value:
+        port = ""
+        if hasattr(adapter, "port"):
+            port = f":{adapter.port}"
+        elif adapter.method in [
+            SparkConnectionMethod.HTTP.value,
+            SparkConnectionMethod.ODBC.value,
+        ]:
+            port = ":443"
+        elif adapter.method == SparkConnectionMethod.THRIFT.value:
+            port = ":10001"
+
+        if adapter.method in SparkConnectionMethod.methods():
+            return f"spark://{adapter.host}{port}"
+        else:
+            raise NotImplementedError(
+                f"Connection method `{adapter.method}` is not supported for the spark adapter."
+            )
+    else:
+        raise NotImplementedError(
+            f"Only {Adapter.adapters()} adapters are supported right now. Passed {adapter.type}"
+        )
+
+
+def _fix_account_name(name: str) -> str:
+    if not any(word in name for word in ["-", "_"]):
+        # If there is neither '-' nor '_' in the name, append `.us-west-1.aws`.
+        return f"{name}.us-west-1.aws"
+
+    if "." in name:
+        # Logic for account locators with dots remains unchanged.
+        spl = name.split(".")
+        if len(spl) == 1:
+            account = spl[0]
+            region, cloud = "us-west-1", "aws"
+        elif len(spl) == 2:
+            account, region = spl
+            cloud = "aws"
+        else:
+            account, region, cloud = spl
+        return f"{account}.{region}.{cloud}"
+
+    # Check for existing account names that embed a cloud name.
+    if cloud := next((c for c in ["aws", "gcp", "azure"] if c in name), ""):
+        parts = name.split(cloud)
+        account = parts[0].strip("-_.")
+
+        if not (region := parts[1].strip("-_.").replace("_", "-")):
+            return name
+        return f"{account}.{region}.{cloud}"

+    # Default case: return the original name.
+    return name
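Taken together, these helpers supply the dataset namespace and facets for OpenLineage run events. A hedged sketch of how they might combine with the client classes imported at the top of the file (the URL, job namespace, and job name are placeholders, and profile_credentials stands in for the active adapter's Credentials object):

from datetime import datetime, timezone

client = OpenLineageClient(url="http://localhost:5000")  # placeholder endpoint

namespace = extract_namespace(profile_credentials)  # e.g. "postgres://host:5432"
event = RunEvent(
    eventType=RunState.START,
    eventTime=datetime.now(timezone.utc).isoformat(),
    run=Run(runId=str(generate_new_uuid())),
    job=Job(namespace="dbt", name="my_project.my_model"),  # placeholder job
    inputs=[],
    outputs=[],
)
client.emit(event)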
