
Commit 9e2209b

--wip-- [skip ci]
1 parent 87584c7

31 files changed, +16582 −9 lines changed

core/dbt/cli/requires.py

Lines changed: 8 additions & 1 deletion
@@ -31,6 +31,8 @@
 from dbt.exceptions import DbtProjectError, FailFastError
 from dbt.flags import get_flag_dict, get_flags, set_flags
 from dbt.mp_context import get_mp_context
+from dbt.openlineage.common.utils import is_runnable_dbt_command
+from dbt.openlineage.handler import OpenLineageHandler
 from dbt.parser.manifest import parse_manifest
 from dbt.plugins import set_up_plugin_manager
 from dbt.profiler import profiler
@@ -82,8 +84,13 @@ def wrapper(*args, **kwargs):
         # Reset invocation_id for each 'invocation' of a dbt command (can happen multiple times in a single process)
         reset_invocation_id()

-        # Logging
+        # OpenLineage
+        ol_handler = OpenLineageHandler(ctx)
         callbacks = ctx.obj.get("callbacks", [])
+        if is_runnable_dbt_command(flags):
+            callbacks.append(ol_handler.handle)
+
+        # Logging
         setup_event_logger(flags=flags, callbacks=callbacks)

         # Tracking
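The OpenLineage handler is registered as an event callback only for commands that actually execute nodes. The predicate is_runnable_dbt_command is imported but not shown in this diff; a minimal sketch of what such a check could look like, assuming it keys off flags.which and a hypothetical allow-list of commands:

# Hypothetical sketch -- the real is_runnable_dbt_command is not part of this diff.
# Assumes flags.which holds the invoked subcommand name, as it does elsewhere in dbt.
RUNNABLE_COMMANDS = {"run", "build", "test", "seed", "snapshot"}  # assumed allow-list

def is_runnable_dbt_command(flags) -> bool:
    # Only attach the OpenLineage callback for commands that execute nodes.
    return getattr(flags, "which", None) in RUNNABLE_COMMANDS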

core/dbt/events/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -1,8 +1,11 @@
 from typing import Any, Dict, Set

+import dbt.adapters.events.adapter_types_pb2 as adapter_types_pb2
 import dbt.adapters.events.types as adapter_dbt_event_types
+import dbt.events.core_types_pb2 as core_dbt_event_types_pb2
 import dbt.events.types as core_dbt_event_types
 import dbt_common.events.types as dbt_event_types
+import dbt_common.events.types_pb2 as dbt_event_types_pb2

 ALL_EVENT_TYPES: Dict[str, Any] = {
     **dbt_event_types.__dict__,
@@ -13,3 +16,10 @@
 ALL_EVENT_NAMES: Set[str] = set(
     [name for name, cls in ALL_EVENT_TYPES.items() if isinstance(cls, type)]
 )
+
+
+ALL_PROTO_TYPES: Dict[str, Any] = {
+    **dbt_event_types_pb2.__dict__,
+    **core_dbt_event_types_pb2.__dict__,
+    **adapter_types_pb2.__dict__,
+}
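The new ALL_PROTO_TYPES mapping mirrors ALL_EVENT_TYPES, but over the generated protobuf modules, so a protobuf message class can be looked up by name. A minimal sketch of that lookup, using a hypothetical event name:

# Illustrative only; "MainReportVersionMsg" is assumed here, not taken from this diff.
msg_cls = ALL_PROTO_TYPES.get("MainReportVersionMsg")
if msg_cls is not None:
    msg = msg_cls()  # instantiate an empty protobuf message of that type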

core/dbt/events/types.py

Lines changed: 11 additions & 0 deletions
@@ -2416,3 +2416,14 @@ def code(self) -> str:

     def message(self) -> str:
         return f"Artifacts skipped for command : {self.msg}"
+
+
+class OpenLineageException(WarnLevel):
+    def code(self) -> str:
+        return "Z064"
+
+    def message(self):
+        return (
+            f"Encountered an error while creating OpenLineageEvent: {self.exc}\n"
+            f"{self.exc_info}"
+        )
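OpenLineageException is a WarnLevel event, so failures while emitting lineage surface as warnings rather than aborting the run. A minimal sketch of firing it, assuming the backing proto message defines exc and exc_info string fields (the class above reads self.exc and self.exc_info):

import traceback

from dbt_common.events.functions import fire_event
from dbt.events.types import OpenLineageException

try:
    emit_openlineage_event()  # hypothetical call that may raise
except Exception as e:
    # Warn and continue; lineage errors should not fail the dbt invocation.
    fire_event(OpenLineageException(exc=str(e), exc_info=traceback.format_exc()))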

core/dbt/openlineage/__init__.py

Whitespace-only changes.

core/dbt/openlineage/common/__init__.py

Whitespace-only changes.
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@ (new file)

import enum
from typing import Any, List, Optional, Union, no_type_check

from openlineage.client.event_v2 import Dataset
from openlineage.client.facet_v2 import (
    datasource_dataset,
    documentation_dataset,
    schema_dataset,
)

from dbt.adapters.contracts.connection import Credentials
from dbt.artifacts.resources.types import NodeType
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import (
    GenericTestNode,
    ManifestNode,
    ManifestSQLNode,
    ModelNode,
    SeedNode,
    SingularTestNode,
    SourceDefinition,
)


class Adapter(enum.Enum):
    # supported adapters
    BIGQUERY = "bigquery"
    SNOWFLAKE = "snowflake"
    REDSHIFT = "redshift"
    SPARK = "spark"
    POSTGRES = "postgres"
    DATABRICKS = "databricks"
    SQLSERVER = "sqlserver"
    DREMIO = "dremio"
    ATHENA = "athena"
    DUCKDB = "duckdb"
    TRINO = "trino"

    @staticmethod
    def adapters() -> str:
        # String representation of all supported adapter names
        return ",".join([f"`{x.value}`" for x in list(Adapter)])


class SparkConnectionMethod(enum.Enum):
    THRIFT = "thrift"
    ODBC = "odbc"
    HTTP = "http"

    @staticmethod
    def methods():
        return [x.value for x in SparkConnectionMethod]


def extract_schema_dataset_facet(
    node: Union[ManifestNode, SourceDefinition],
) -> List[schema_dataset.SchemaDatasetFacetFields]:
    if node.resource_type == NodeType.Seed:
        return _extract_schema_dataset_from_seed(node)
    else:
        return _extract_schema_dataset_facet_from_manifest_sql_node(node)


def _extract_schema_dataset_facet_from_manifest_sql_node(
    manifest_sql_node: Union[ManifestNode, SourceDefinition],
) -> List[schema_dataset.SchemaDatasetFacetFields]:
    schema_fields = []
    for column_info in manifest_sql_node.columns.values():
        description = column_info.description
        name = column_info.name
        data_type = column_info.data_type or ""
        schema_fields.append(
            schema_dataset.SchemaDatasetFacetFields(
                name=name, description=description, type=data_type
            )
        )
    return schema_fields


def _extract_schema_dataset_from_seed(
    seed: SeedNode,
) -> List[schema_dataset.SchemaDatasetFacetFields]:
    schema_fields = []
    for col_name in seed.config.column_types:
        col_type = seed.config.column_types[col_name]
        schema_fields.append(schema_dataset.SchemaDatasetFacetFields(name=col_name, type=col_type))
    return schema_fields


def get_model_inputs(
    node_unique_id: str, manifest: Manifest
) -> List[ManifestNode | SourceDefinition]:
    upstreams: List[ManifestNode | SourceDefinition] = []
    input_node_ids = manifest.parent_map.get(node_unique_id, [])
    for input_node_id in input_node_ids:
        if input_node_id.startswith("source."):
            upstreams.append(manifest.sources[input_node_id])
        else:
            upstreams.append(manifest.nodes[input_node_id])
    return upstreams


def node_to_dataset(
    node: Union[ManifestNode, SourceDefinition], dataset_namespace: str
) -> Dataset:
    facets = {
        "dataSource": datasource_dataset.DatasourceDatasetFacet(
            name=dataset_namespace, uri=dataset_namespace
        ),
        "schema": schema_dataset.SchemaDatasetFacet(fields=extract_schema_dataset_facet(node)),
        "documentation": documentation_dataset.DocumentationDatasetFacet(
            description=node.description
        ),
    }
    node_fqn = ".".join(node.fqn)
    return Dataset(namespace=dataset_namespace, name=node_fqn, facets=facets)


def get_test_column(test_node: GenericTestNode) -> Optional[str]:
    return test_node.test_metadata.kwargs.get("column_name")


@no_type_check
def extract_namespace(adapter: Credentials) -> str:
    # Extract namespace from profile's type
    if adapter.type == Adapter.SNOWFLAKE.value:
        return f"snowflake://{_fix_account_name(adapter.account)}"
    elif adapter.type == Adapter.BIGQUERY.value:
        return "bigquery"
    elif adapter.type == Adapter.REDSHIFT.value:
        return f"redshift://{adapter.host}:{adapter.port}"
    elif adapter.type == Adapter.POSTGRES.value:
        return f"postgres://{adapter.host}:{adapter.port}"
    elif adapter.type == Adapter.TRINO.value:
        return f"trino://{adapter.host}:{adapter.port}"
    elif adapter.type == Adapter.DATABRICKS.value:
        return f"databricks://{adapter.host}"
    elif adapter.type == Adapter.SQLSERVER.value:
        return f"mssql://{adapter.server}:{adapter.port}"
    elif adapter.type == Adapter.DREMIO.value:
        return f"dremio://{adapter.software_host}:{adapter.port}"
    elif adapter.type == Adapter.ATHENA.value:
        return f"awsathena://athena.{adapter.region_name}.amazonaws.com"
    elif adapter.type == Adapter.DUCKDB.value:
        return f"duckdb://{adapter.path}"
    elif adapter.type == Adapter.SPARK.value:
        port = ""
        if hasattr(adapter, "port"):
            port = f":{adapter.port}"
        elif adapter.method in [
            SparkConnectionMethod.HTTP.value,
            SparkConnectionMethod.ODBC.value,
        ]:
            port = "443"
        elif adapter.method == SparkConnectionMethod.THRIFT.value:
            port = "10001"

        if adapter.method in SparkConnectionMethod.methods():
            return f"spark://{adapter.host}{port}"
        else:
            raise NotImplementedError(
                f"Connection method `{adapter.method}` is not " f"supported for spark adapter."
            )
    else:
        raise NotImplementedError(
            f"Only {Adapter.adapters()} adapters are supported right now. "
            f"Passed {adapter.type}"
        )


def _fix_account_name(name: str) -> str:
    if not any(word in name for word in ["-", "_"]):
        # If there is neither '-' nor '_' in the name, we append `.us-west-1.aws`
        return f"{name}.us-west-1.aws"

    if "." in name:
        # Logic for account locator with dots remains unchanged
        spl = name.split(".")
        if len(spl) == 1:
            account = spl[0]
            region, cloud = "us-west-1", "aws"
        elif len(spl) == 2:
            account, region = spl
            cloud = "aws"
        else:
            account, region, cloud = spl
        return f"{account}.{region}.{cloud}"

    # Check for existing accounts with cloud names
    if cloud := next((c for c in ["aws", "gcp", "azure"] if c in name), ""):
        parts = name.split(cloud)
        account = parts[0].strip("-_.")

        if not (region := parts[1].strip("-_.").replace("_", "-")):
            return name
        return f"{account}.{region}.{cloud}"

    # Default case, return the original name
    return name
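To illustrate how the namespace helpers behave, a short sketch of expected values derived from the code above (profile values are hypothetical stand-ins):

# Expected behaviour of the Snowflake account normalizer, per the helper above:
assert _fix_account_name("xy12345") == "xy12345.us-west-1.aws"
assert _fix_account_name("xy12345.us-east-1") == "xy12345.us-east-1.aws"

# extract_namespace() keys off the profile's adapter type: a Postgres profile with
# host="localhost" and port=5432 (hypothetical values) yields "postgres://localhost:5432".
# node_to_dataset() then names the dataset by the node's dot-joined fqn,
# e.g. "my_project.staging.stg_orders".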
