Skip to content

Commit 3ca690c

Browse files
committed
convert transformation functions to use yield instead of return
1 parent fdc3079 commit 3ca690c

File tree

3 files changed

+64
-64
lines changed

3 files changed

+64
-64
lines changed

dlt/transformations/transformation.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from typing import Callable, Any, Optional, Type, Iterator, List
44

55
import dlt
6+
import sqlglot
67

78
from dlt.common.configuration.inject import get_fun_last_config, get_fun_spec
8-
from dlt.common.reflection.inspect import isgeneratorfunction
99
from dlt.common.typing import TDataItems, TTableHintTemplate
1010
from dlt.common import logger
1111

@@ -16,7 +16,6 @@
1616
from dlt.transformations.typing import TTransformationFunParams
1717
from dlt.transformations.exceptions import (
1818
TransformationException,
19-
TransformationInvalidReturnTypeException,
2019
IncompatibleDatasetsException,
2120
)
2221
from dlt.pipeline.exceptions import PipelineConfigMissing
@@ -32,7 +31,6 @@
3231
from dlt.transformations.configuration import TransformationConfiguration
3332
from dlt.common.utils import get_callable_name
3433
from dlt.extract.exceptions import CurrentSourceNotAvailable
35-
from dlt.common.schema.typing import TPartialTableSchema
3634
from dlt.extract.pipe_iterator import DataItemWithMeta
3735

3836

@@ -82,7 +80,6 @@ def make_transformation_resource(
8280
section: Optional[TTableHintTemplate[str]],
8381
) -> DltTransformationResource:
8482
resource_name = name if name and not callable(name) else get_callable_name(func)
85-
is_regular_resource = isgeneratorfunction(func)
8683

8784
if spec and not issubclass(spec, TransformationConfiguration):
8885
raise TransformationException(
@@ -92,16 +89,53 @@ def make_transformation_resource(
9289

9390
@wraps(func)
9491
def transformation_function(*args: Any, **kwargs: Any) -> Iterator[TDataItems]:
95-
config: TransformationConfiguration = (
96-
get_fun_last_config(func) or get_fun_spec(func)() # type: ignore[assignment]
97-
)
98-
9992
# Collect all datasets from args and kwargs
10093
all_arg_values = list(args) + list(kwargs.values())
10194
datasets: List[ReadableDBAPIDataset] = [
10295
arg for arg in all_arg_values if isinstance(arg, ReadableDBAPIDataset)
10396
]
10497

98+
# get first item from gen and see what we're dealing with
99+
gen = func(*args, **kwargs)
100+
original_first_item = next(gen)
101+
102+
# unwrap if needed
103+
meta = None
104+
unwrapped_item = original_first_item
105+
relation = None
106+
if isinstance(original_first_item, DataItemWithMeta):
107+
meta = original_first_item.meta
108+
unwrapped_item = original_first_item.data
109+
110+
# catch the two cases where we get a relation from the transformation function
111+
# NOTE: we only process the first item, all other things that are still in the generator are ignored
112+
if isinstance(unwrapped_item, BaseReadableDBAPIRelation):
113+
relation = unwrapped_item
114+
# we see if the string is a valid sql query, if so we need a dataset
115+
elif isinstance(unwrapped_item, str):
116+
try:
117+
sqlglot.parse_one(unwrapped_item)
118+
if len(datasets) == 0:
119+
raise IncompatibleDatasetsException(
120+
resource_name,
121+
"No datasets found in transformation function arguments. Please supply all"
122+
" used datasets via transform function arguments.",
123+
)
124+
else:
125+
relation = datasets[0](unwrapped_item)
126+
except sqlglot.errors.ParseError:
127+
pass
128+
129+
# we have something else, so fall back to regular resource behavior
130+
if not relation:
131+
yield original_first_item
132+
yield from gen
133+
return
134+
135+
config: TransformationConfiguration = (
136+
get_fun_last_config(func) or get_fun_spec(func)() # type: ignore[assignment]
137+
)
138+
105139
# Warn if Incremental arguments are present
106140
for arg_name, param in inspect.signature(func).parameters.items():
107141
if param.annotation is Incremental or isinstance(param.default, Incremental):
@@ -138,31 +172,8 @@ def transformation_function(*args: Any, **kwargs: Any) -> Iterator[TDataItems]:
138172
# respect always materialize config
139173
should_materialize = should_materialize or config.always_materialize
140174

141-
# Call the transformation function
142-
transformation_result: Any = func(*args, **kwargs)
143-
144-
# unwrap meta
145-
meta = None
146-
if isinstance(transformation_result, DataItemWithMeta):
147-
meta = transformation_result.meta
148-
transformation_result = transformation_result.data
149-
150-
# If a string is returned, construct relation from first dataset from it
151-
if isinstance(transformation_result, BaseReadableDBAPIRelation):
152-
relation = transformation_result
153-
elif isinstance(transformation_result, str):
154-
relation = datasets[0](transformation_result)
155-
else:
156-
raise TransformationInvalidReturnTypeException(
157-
resource_name,
158-
"Sql Transformation %s returned an invalid type: %s. Please either return a valid"
159-
" sql string or Ibis / data frame expression from a dataset. If you want to return"
160-
" data (data frames / arrow table), please yield those, not return."
161-
% (name, type(transformation_result)),
162-
)
163-
175+
# build model if needed
164176
sql_model = MaterializableSqlModel.from_relation(relation)
165-
166177
if not should_materialize:
167178
if meta:
168179
yield DataItemWithMeta(meta, sql_model)
@@ -188,6 +199,4 @@ def transformation_function(*args: Any, **kwargs: Any) -> Iterator[TDataItems]:
188199
section=section,
189200
_impl_cls=DltTransformationResource,
190201
_base_spec=TransformationConfiguration,
191-
)(
192-
func if is_regular_resource else transformation_function # type: ignore[arg-type]
193-
)
202+
)(transformation_function)

tests/transformations/test_transformation_decorator.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,23 @@
2222

2323

2424
def test_no_datasets_used() -> None:
25+
# valid sql string without dataset will raise
2526
with pytest.raises(IncompatibleDatasetsException) as excinfo:
2627

2728
@dlt.transformation()
2829
def transform() -> Any:
29-
return {"some": "data"}
30+
yield "SELECT * FROM table1"
3031

3132
list(transform())
3233

33-
assert "No datasets detected in transformation. Please supply all used datasets via" in str(
34-
excinfo.value
35-
)
34+
assert "No datasets found in transformation function arguments" in str(excinfo.value)
35+
36+
# invalid sql string without dataset will be interpreted as string item
37+
@dlt.transformation()
38+
def other_transform() -> Any:
39+
yield "Hello I am a string"
40+
41+
assert list(other_transform()) == ["Hello I am a string"]
3642

3743

3844
def test_iterator_function_as_transform_function() -> None:
@@ -44,19 +50,6 @@ def transform(dataset: SupportsReadableDataset[Any]) -> Any:
4450
assert list(transform(dlt.dataset("duckdb", "dataset_name"))) == [{"some": "data"}]
4551

4652

47-
def test_incorrect_transform_function_return_type() -> None:
48-
p = dlt.pipeline("test_pipeline", destination="duckdb")
49-
50-
@dlt.transformation()
51-
def transform(dataset: SupportsReadableDataset[Any]) -> Any:
52-
return {"some": "data"}
53-
54-
with pytest.raises(PipelineStepFailed) as excinfo:
55-
p.run(transform(dlt.dataset(dlt.destinations.duckdb("input_data"), "dataset_name")))
56-
57-
assert "Please either return a valid sql string or" in str(excinfo.value)
58-
59-
6053
def test_incremental_argument_is_not_supported(caplog: LogCaptureFixture) -> None:
6154
# test incremental default arg
6255
with patch.object(logger, "warning") as mock_warning:
@@ -68,7 +61,7 @@ def transform_1(
6861
dataset: SupportsReadableDataset[Any],
6962
incremental_arg=dlt.sources.incremental("col1"),
7063
) -> Any:
71-
return "SELECT col1 FROM table1"
64+
yield "SELECT col1 FROM table1"
7265

7366
list(transform_1(dlt.dataset("duckdb", "dataset_name")))
7467

@@ -127,7 +120,7 @@ def default_spec(dataset: dlt.Dataset):
127120
assert type(config) is not TransformationConfiguration
128121
# config got passed
129122
assert config.buffer_max_items == 100
130-
return "SELECT col1 FROM table1"
123+
yield "SELECT col1 FROM table1"
131124

132125
schema = Schema("_data")
133126
schema.update_table(new_table("table1", columns=[{"name": "col1", "data_type": "text"}]))
@@ -146,7 +139,7 @@ def default_transformation_with_args(
146139
dataset: dlt.Dataset, last_id: str = dlt.config.value, limit: int = 5
147140
):
148141
assert last_id == "test_last_id"
149-
return dataset.table1[["col1"]]
142+
yield dataset.table1[["col1"]]
150143

151144
spec = get_fun_spec(default_transformation_with_args)
152145
assert "last_id" in spec().get_resolvable_fields()
@@ -178,7 +171,7 @@ def default_transformation_spec(
178171
assert limit == 100
179172

180173
table1_ = dataset(f"SELECT * FROM table1 WHERE col1 = '{last_idx}' LIMIT {limit}")
181-
return table1_
174+
yield table1_
182175

183176
assert default_transformation_spec.name == "default_name_ovr"
184177
assert default_transformation_spec.section == "default_name_ovr"
@@ -191,8 +184,6 @@ def default_transformation_spec(
191184
assert isinstance(model, SqlModel)
192185
query = model.query
193186
# make sure we have our args in query
194-
print(model)
195-
print(query)
196187
assert "uniq_last_id" in query
197188
assert "100" in query
198189

tests/transformations/test_transformations.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ def test_simple_query_transformations(
4242

4343
@dlt.transformation()
4444
def copied_purchases(dataset: SupportsReadableDataset[Any]) -> Any:
45-
return """SELECT * FROM purchases LIMIT 3"""
45+
yield """SELECT * FROM purchases LIMIT 3"""
4646

4747
elif transformation_type == "relation":
4848

4949
@dlt.transformation()
5050
def copied_purchases(dataset: SupportsReadableDataset[Any]) -> Any:
51-
return dataset["purchases"].limit(3)
51+
yield dataset["purchases"].limit(3)
5252

5353
# transform into transformed dataset
5454
os.environ["ALWAYS_MATERIALIZE"] = str(always_materialize)
@@ -87,12 +87,12 @@ def test_transformations_with_supplied_hints(
8787
# we can now transform this table twice, one with changed hints and once with the original hints
8888
@dlt.transformation()
8989
def inventory_original(dataset: SupportsReadableDataset[Any]) -> Any:
90-
return dataset["inventory"]
90+
yield dataset["inventory"]
9191

9292
@dlt.transformation()
9393
def inventory_more_precise(dataset: SupportsReadableDataset[Any]) -> Any:
9494
hints = make_hints(columns=[{"name": "price", "precision": 20, "scale": 2}])
95-
return dlt.mark.with_hints(dataset["inventory"], hints=hints)
95+
yield dlt.mark.with_hints(dataset["inventory"], hints=hints)
9696

9797
dest_p.run([inventory_original(fruit_p.dataset()), inventory_more_precise(fruit_p.dataset())])
9898

@@ -119,7 +119,7 @@ def test_extract_without_source_name_or_pipeline(
119119

120120
@dlt.transformation()
121121
def buffer_size_test(dataset: SupportsReadableDataset[Any]) -> Any:
122-
return dataset["customers"]
122+
yield dataset["customers"]
123123

124124
# transformations switch to model extraction
125125
fruit_p.deactivate()
@@ -139,7 +139,7 @@ def test_extract_without_destination(destination_config: DestinationTestConfigur
139139

140140
@dlt.transformation()
141141
def extract_test(dataset: SupportsReadableDataset[Any]) -> Any:
142-
return dataset["customers"]
142+
yield dataset["customers"]
143143

144144
pipeline_no_destination = dlt.pipeline(pipeline_name="no_destination")
145145
pipeline_no_destination._destination = None

0 commit comments

Comments
 (0)