Skip to content

Commit 6ae764d

Browse files
committed
add materializable sqlmodel (wip)
1 parent 796f215 commit 6ae764d

File tree

4 files changed

+62
-37
lines changed

4 files changed

+62
-37
lines changed

dlt/destinations/dataset/dataset.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,13 @@ def schema(self) -> Schema:
9090
def sqlglot_schema(self) -> SQLGlotSchema:
9191
# NOTE: no cache for now, it is probably more expensive to compute the current schema hash
9292
# to see whether this is stale than to compute a new sqlglot schema
93-
dialect: str = self.sql_client.capabilities.sqlglot_dialect
94-
return lineage.create_sqlglot_schema(self.schema, self.dataset_name, dialect=dialect)
93+
return lineage.create_sqlglot_schema(
94+
self.schema, self.dataset_name, dialect=self.sqlglot_dialect
95+
)
96+
97+
@property
98+
def sqlglot_dialect(self) -> str:
99+
return self.sql_client.capabilities.sqlglot_dialect
95100

96101
@property
97102
def dataset_name(self) -> str:

dlt/destinations/dataset/relation.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,10 @@ def query(self, pretty: bool = False) -> str:
7979
"Must be an SQL SELECT statement."
8080
)
8181

82-
return query.sql(
83-
dialect=self._dataset.sql_client.capabilities.sqlglot_dialect, pretty=pretty
84-
)
82+
return query.sql(dialect=self._dataset.sqlglot_dialect, pretty=pretty)
83+
84+
def query_dialect(self) -> str:
85+
return self._dataset.sqlglot_dialect
8586

8687
def _query(self) -> sge.Query:
8788
"""Returns a query compliant with the dlt schema in the relation.
@@ -249,16 +250,11 @@ def __init__(
249250
self._selected_columns = selected_columns
250251

251252
def _query(self) -> sge.Query:
252-
destination_dialect = self._dataset.sql_client.capabilities.sqlglot_dialect
253-
254253
# TODO reimplement this using SQLGLot instead of passing strings
255254
if self._provided_query:
256255
return cast(
257256
sge.Query,
258-
sqlglot.parse_one(
259-
self._provided_query,
260-
dialect=self._provided_query_dialect or destination_dialect,
261-
),
257+
sqlglot.parse_one(self._provided_query, dialect=self.query_dialect()),
262258
)
263259

264260
dataset_schema = self._dataset.schema
@@ -285,11 +281,14 @@ def _query(self) -> sge.Query:
285281
),
286282
)
287283

284+
def query_dialect(self) -> str:
285+
return self._provided_query_dialect or self._dataset.sqlglot_dialect
286+
288287
def __copy__(self) -> Self:
289288
return self.__class__(
290289
readable_dataset=self._dataset,
291290
provided_query=self._provided_query,
292-
provided_query_dialect=self._provided_query_dialect,
291+
provided_query_dialect=self.query_dialect(),
293292
table_name=self._table_name,
294293
limit=self._limit,
295294
selected_columns=self._selected_columns,

dlt/extract/hints.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
from dlt.extract.items_transform import ValidateItem
5656
from dlt.extract.utils import ensure_table_schema_columns, ensure_table_schema_columns_hint
5757
from dlt.extract.validation import create_item_validator
58-
from dlt.common.data_writers import TDataItemFormat
5958

6059
import sqlglot
6160

@@ -99,9 +98,17 @@ def __init__(
9998
self.create_table_variant = create_table_variant
10099

101100

102-
class SqlModel(NamedTuple):
103-
query: str
104-
dialect: Optional[str] = None
101+
class SqlModel:
102+
"""
103+
A SqlModel is a lightweight container that holds a query and a dialect.
104+
It is used to represent a SQL query and the dialect to use for parsing it.
105+
"""
106+
107+
__slots__ = ("query", "dialect")
108+
109+
def __init__(self, query: str, dialect: Optional[str] = None) -> None:
110+
self.query = query
111+
self.dialect = dialect
105112

106113
@classmethod
107114
def from_query_string(cls, query: str, dialect: Optional[str] = None) -> "SqlModel":

dlt/transformations/transformation.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,35 @@
2929
TTableFormat,
3030
TTableReferenceParam,
3131
)
32-
from dlt.common.destination.dataset import SupportsReadableRelation
3332
from dlt.transformations.configuration import TransformationConfiguration
3433
from dlt.common.utils import get_callable_name
3534
from dlt.extract.exceptions import CurrentSourceNotAvailable
3635

3736

37+
class MaterializableSqlModel(SqlModel):
38+
# NOTE: we could forward all data access methods to this class
39+
__slots__ = ("relation",)
40+
41+
def __init__(
42+
self,
43+
relation: Optional[BaseReadableDBAPIRelation] = None,
44+
) -> None:
45+
super().__init__(relation.query(), relation.query_dialect())
46+
self.relation = relation
47+
48+
@classmethod
49+
def from_relation(cls, relation: BaseReadableDBAPIRelation) -> "MaterializableSqlModel":
50+
return cls(relation=relation)
51+
52+
def compute_columns(self) -> TTableSchemaColumns:
53+
computed_columns, _ = self.relation._compute_columns_schema(
54+
infer_sqlglot_schema=True,
55+
allow_anonymous_columns=True,
56+
allow_partial=True,
57+
)
58+
return computed_columns
59+
60+
3861
class DltTransformationResource(DltResource):
3962
def __init__(self, *args: Any, **kwds: Any) -> None:
4063
super().__init__(*args, **kwds)
@@ -113,11 +136,12 @@ def transformation_function(*args: Any, **kwargs: Any) -> Iterator[TDataItems]:
113136
# Call the transformation function
114137
transformation_result: Any = func(*args, **kwargs)
115138

116-
# If a string is returned, treat it as a SQL query
117-
if isinstance(transformation_result, str):
118-
transformation_result = datasets[0](transformation_result)
119-
120-
if not isinstance(transformation_result, BaseReadableDBAPIRelation):
139+
# Use a returned relation as is; if a string is returned, construct a relation from it using the first dataset
140+
if isinstance(transformation_result, BaseReadableDBAPIRelation):
141+
relation = transformation_result
142+
elif isinstance(transformation_result, str):
143+
relation = datasets[0](transformation_result)
144+
else:
121145
raise TransformationInvalidReturnTypeException(
122146
resource_name,
123147
"Sql Transformation %s returned an invalid type: %s. Please either return a valid"
@@ -126,24 +150,14 @@ def transformation_function(*args: Any, **kwargs: Any) -> Iterator[TDataItems]:
126150
% (name, type(transformation_result)),
127151
)
128152

129-
# Compute columns schema
130-
computed_columns, _ = transformation_result._compute_columns_schema(
131-
infer_sqlglot_schema=True,
132-
allow_anonymous_columns=True,
133-
allow_partial=True,
134-
)
135-
select_dialect = datasets[0].sql_client.capabilities.sqlglot_dialect
136-
select_query = transformation_result.query()
137-
all_columns = {**computed_columns, **(columns or {})}
153+
sql_model = MaterializableSqlModel.from_relation(relation)
138154

139155
if not should_materialize:
140-
yield dlt.mark.with_hints(
141-
SqlModel(select_query, dialect=select_dialect),
142-
hints=make_hints(columns=all_columns),
143-
)
156+
yield sql_model
144157
else:
145-
for chunk in datasets[0](select_query).iter_arrow(chunk_size=config.buffer_max_items):
146-
yield dlt.mark.with_hints(chunk, hints=make_hints(columns=all_columns))
158+
column_hints = make_hints(columns=sql_model.compute_columns())
159+
for chunk in relation.iter_arrow(chunk_size=config.buffer_max_items):
160+
yield dlt.mark.with_hints(chunk, hints=column_hints)
147161

148162
return dlt.resource( # type: ignore[return-value]
149163
name=name,

0 commit comments

Comments
 (0)