33from typing import Callable , Any , Optional , Type , Iterator , List
44
55import dlt
6+ import sqlglot
67
78from dlt .common .configuration .inject import get_fun_last_config , get_fun_spec
8- from dlt .common .reflection .inspect import isgeneratorfunction
99from dlt .common .typing import TDataItems , TTableHintTemplate
1010from dlt .common import logger
1111
1616from dlt .transformations .typing import TTransformationFunParams
1717from dlt .transformations .exceptions import (
1818 TransformationException ,
19- TransformationInvalidReturnTypeException ,
2019 IncompatibleDatasetsException ,
2120)
2221from dlt .pipeline .exceptions import PipelineConfigMissing
3231from dlt .transformations .configuration import TransformationConfiguration
3332from dlt .common .utils import get_callable_name
3433from dlt .extract .exceptions import CurrentSourceNotAvailable
35- from dlt .common .schema .typing import TPartialTableSchema
3634from dlt .extract .pipe_iterator import DataItemWithMeta
3735
3836
@@ -82,7 +80,6 @@ def make_transformation_resource(
8280 section : Optional [TTableHintTemplate [str ]],
8381) -> DltTransformationResource :
8482 resource_name = name if name and not callable (name ) else get_callable_name (func )
85- is_regular_resource = isgeneratorfunction (func )
8683
8784 if spec and not issubclass (spec , TransformationConfiguration ):
8885 raise TransformationException (
@@ -92,16 +89,53 @@ def make_transformation_resource(
9289
9390 @wraps (func )
9491 def transformation_function (* args : Any , ** kwargs : Any ) -> Iterator [TDataItems ]:
95- config : TransformationConfiguration = (
96- get_fun_last_config (func ) or get_fun_spec (func )() # type: ignore[assignment]
97- )
98-
9992 # Collect all datasets from args and kwargs
10093 all_arg_values = list (args ) + list (kwargs .values ())
10194 datasets : List [ReadableDBAPIDataset ] = [
10295 arg for arg in all_arg_values if isinstance (arg , ReadableDBAPIDataset )
10396 ]
10497
98+ # get first item from gen and see what we're dealing with
99+ gen = func (* args , ** kwargs )
100+ original_first_item = next (gen )
101+
102+ # unwrap if needed
103+ meta = None
104+ unwrapped_item = original_first_item
105+ relation = None
106+ if isinstance (original_first_item , DataItemWithMeta ):
107+ meta = original_first_item .meta
108+ unwrapped_item = original_first_item .data
109+
110+ # catch the two cases where we get a relation from the transformation function
111+ # NOTE: we only process the first item, all other things that are still in the generator are ignored
112+ if isinstance (unwrapped_item , BaseReadableDBAPIRelation ):
113+ relation = unwrapped_item
114+ # we see if the string is a valid sql query, if so we need a dataset
115+ elif isinstance (unwrapped_item , str ):
116+ try :
117+ sqlglot .parse_one (unwrapped_item )
118+ if len (datasets ) == 0 :
119+ raise IncompatibleDatasetsException (
120+ resource_name ,
121+ "No datasets found in transformation function arguments. Please supply all"
122+ " used datasets via transform function arguments." ,
123+ )
124+ else :
125+ relation = datasets [0 ](unwrapped_item )
126+ except sqlglot .errors .ParseError :
127+ pass
128+
129+ # we have something else, so fall back to regular resource behavior
130+ if not relation :
131+ yield original_first_item
132+ yield from gen
133+ return
134+
135+ config : TransformationConfiguration = (
136+ get_fun_last_config (func ) or get_fun_spec (func )() # type: ignore[assignment]
137+ )
138+
105139 # Warn if Incremental arguments are present
106140 for arg_name , param in inspect .signature (func ).parameters .items ():
107141 if param .annotation is Incremental or isinstance (param .default , Incremental ):
@@ -138,31 +172,8 @@ def transformation_function(*args: Any, **kwargs: Any) -> Iterator[TDataItems]:
138172 # respect always materialize config
139173 should_materialize = should_materialize or config .always_materialize
140174
141- # Call the transformation function
142- transformation_result : Any = func (* args , ** kwargs )
143-
144- # unwrap meta
145- meta = None
146- if isinstance (transformation_result , DataItemWithMeta ):
147- meta = transformation_result .meta
148- transformation_result = transformation_result .data
149-
150- # If a string is returned, construct relation from first dataset from it
151- if isinstance (transformation_result , BaseReadableDBAPIRelation ):
152- relation = transformation_result
153- elif isinstance (transformation_result , str ):
154- relation = datasets [0 ](transformation_result )
155- else :
156- raise TransformationInvalidReturnTypeException (
157- resource_name ,
158- "Sql Transformation %s returned an invalid type: %s. Please either return a valid"
159- " sql string or Ibis / data frame expression from a dataset. If you want to return"
160- " data (data frames / arrow table), please yield those, not return."
161- % (name , type (transformation_result )),
162- )
163-
175+ # build model if needed
164176 sql_model = MaterializableSqlModel .from_relation (relation )
165-
166177 if not should_materialize :
167178 if meta :
168179 yield DataItemWithMeta (meta , sql_model )
@@ -188,6 +199,4 @@ def transformation_function(*args: Any, **kwargs: Any) -> Iterator[TDataItems]:
188199 section = section ,
189200 _impl_cls = DltTransformationResource ,
190201 _base_spec = TransformationConfiguration ,
191- )(
192- func if is_regular_resource else transformation_function # type: ignore[arg-type]
193- )
202+ )(transformation_function )
0 commit comments