1- import datetime # noqa: D100
1+ from __future__ import annotations # noqa: D100
2+
3+ import datetime
24import json
3- from collections . abc import Iterator
5+ import typing
46from decimal import Decimal
5- from typing import Optional
67
78import polars as pl
89from polars .io .plugins import register_io_source
910
1011import duckdb
11- from duckdb import SQLExpression
1212
13+ if typing .TYPE_CHECKING :
14+ from collections .abc import Iterator
15+
16+ import typing_extensions
17+
18+ _ExpressionTree : typing_extensions .TypeAlias = typing .Dict [str , typing .Union [str , int , "_ExpressionTree" , typing .Any ]] # noqa: UP006
1319
14- def _predicate_to_expression (predicate : pl .Expr ) -> Optional [SQLExpression ]:
20+
21+ def _predicate_to_expression (predicate : pl .Expr ) -> duckdb .Expression | None :
1522 """Convert a Polars predicate expression to a DuckDB-compatible SQL expression.
1623
1724 Parameters:
@@ -31,7 +38,7 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]:
3138 try :
3239 # Convert the tree to SQL
3340 sql_filter = _pl_tree_to_sql (tree )
34- return SQLExpression (sql_filter )
41+ return duckdb . SQLExpression (sql_filter )
3542 except Exception :
3643 # If the conversion fails, we return None
3744 return None
@@ -70,7 +77,7 @@ def _escape_sql_identifier(identifier: str) -> str:
7077 return f'"{ escaped } "'
7178
7279
73- def _pl_tree_to_sql (tree : dict ) -> str :
80+ def _pl_tree_to_sql (tree : _ExpressionTree ) -> str :
7481 """Recursively convert a Polars expression tree (as JSON) to a SQL string.
7582
7683 Parameters:
@@ -91,38 +98,51 @@ def _pl_tree_to_sql(tree: dict) -> str:
9198 Output: "(foo > 5)"
9299 """
93100 [node_type ] = tree .keys ()
94- subtree = tree [node_type ]
95101
96102 if node_type == "BinaryExpr" :
97103 # Binary expressions: left OP right
98- return (
99- "("
100- + " " .join (
101- (
102- _pl_tree_to_sql (subtree ["left" ]),
103- _pl_operation_to_sql (subtree ["op" ]),
104- _pl_tree_to_sql (subtree ["right" ]),
105- )
106- )
107- + ")"
108- )
104+ bin_expr_tree = tree [node_type ]
105+ assert isinstance (bin_expr_tree , dict ), f"A { node_type } should be a dict but got { type (bin_expr_tree )} "
106+ lhs , op , rhs = bin_expr_tree ["left" ], bin_expr_tree ["op" ], bin_expr_tree ["right" ]
107+ assert isinstance (lhs , dict ), f"LHS of a { node_type } should be a dict but got { type (lhs )} "
108+ assert isinstance (op , str ), f"The op of a { node_type } should be a str but got { type (op )} "
109+ assert isinstance (rhs , dict ), f"RHS of a { node_type } should be a dict but got { type (rhs )} "
110+ return f"({ _pl_tree_to_sql (lhs )} { _pl_operation_to_sql (op )} { _pl_tree_to_sql (rhs )} )"
109111 if node_type == "Column" :
110112 # A reference to a column name
111113 # Wrap in quotes to handle special characters
112- return _escape_sql_identifier (subtree )
114+ col_name = tree [node_type ]
115+ assert isinstance (col_name , str ), f"The col name of a { node_type } should be a str but got { type (col_name )} "
116+ return _escape_sql_identifier (col_name )
113117
114118 if node_type in ("Literal" , "Dyn" ):
115119 # Recursively process dynamic or literal values
116- return _pl_tree_to_sql (subtree )
120+ val_tree = tree [node_type ]
121+ assert isinstance (val_tree , dict ), f"A { node_type } should be a dict but got { type (val_tree )} "
122+ return _pl_tree_to_sql (val_tree )
117123
118124 if node_type == "Int" :
119125 # Direct integer literals
120- return str (subtree )
126+ int_literal = tree [node_type ]
127+ assert isinstance (int_literal , (int , str )), (
128+ f"The value of an Int should be an int or str but got { type (int_literal )} "
129+ )
130+ return str (int_literal )
121131
122132 if node_type == "Function" :
123133 # Handle boolean functions like IsNull, IsNotNull
124- inputs = subtree ["input" ]
125- func_dict = subtree ["function" ]
134+ func_tree = tree [node_type ]
135+ assert isinstance (func_tree , dict ), f"A { node_type } should be a dict but got { type (func_tree )} "
136+ inputs = func_tree ["input" ]
137+ assert isinstance (inputs , list ), f"A { node_type } should have a list of dicts as input but got { type (inputs )} "
138+ input_tree = inputs [0 ]
139+ assert isinstance (input_tree , dict ), (
140+ f"A { node_type } should have a list of dicts as input but got { type (input_tree )} "
141+ )
142+ func_dict = func_tree ["function" ]
143+ assert isinstance (func_dict , dict ), (
144+ f"A { node_type } should have a function dict as input but got { type (func_dict )} "
145+ )
126146
127147 if "Boolean" in func_dict :
128148 func = func_dict ["Boolean" ]
@@ -140,24 +160,31 @@ def _pl_tree_to_sql(tree: dict) -> str:
140160
141161 if node_type == "Scalar" :
142162 # Detect format: old style (dtype/value) or new style (direct type key)
143- if "dtype" in subtree and "value" in subtree :
144- dtype = str (subtree ["dtype" ])
145- value = subtree ["value" ]
163+ scalar_tree = tree [node_type ]
164+ assert isinstance (scalar_tree , dict ), f"A { node_type } should be a dict but got { type (scalar_tree )} "
165+ if "dtype" in scalar_tree and "value" in scalar_tree :
166+ dtype = str (scalar_tree ["dtype" ])
167+ value = scalar_tree ["value" ]
146168 else :
147169 # New style: dtype is the single key in the dict
148- dtype = next (iter (subtree .keys ()))
149- value = subtree
170+ dtype = next (iter (scalar_tree .keys ()))
171+ value = scalar_tree
172+ assert isinstance (dtype , str ), f"A { node_type } should have a str dtype but got { type (dtype )} "
173+ assert isinstance (value , dict ), f"A { node_type } should have a dict value but got { type (value )} "
150174
151175 # Decimal support
152176 if dtype .startswith ("{'Decimal'" ) or dtype == "Decimal" :
153177 decimal_value = value ["Decimal" ]
154- decimal_value = Decimal (decimal_value [0 ]) / Decimal (10 ** decimal_value [1 ])
155- return str (decimal_value )
178+ assert isinstance (decimal_value , list ), (
179+ f"A { dtype } should be a two member list but got { type (decimal_value )} "
180+ )
181+ return str (Decimal (decimal_value [0 ]) / Decimal (10 ** decimal_value [1 ]))
156182
157183 # Datetime with microseconds since epoch
158184 if dtype .startswith ("{'Datetime'" ) or dtype == "Datetime" :
159- micros = value ["Datetime" ][0 ]
160- dt_timestamp = datetime .datetime .fromtimestamp (micros / 1_000_000 , tz = datetime .UTC )
185+ micros = value ["Datetime" ]
186+ assert isinstance (micros , list ), f"A { dtype } should be a one member list but got { type (micros )} "
187+ dt_timestamp = datetime .datetime .fromtimestamp (micros [0 ] / 1_000_000 , tz = datetime .timezone .utc )
161188 return f"'{ dt_timestamp !s} '::TIMESTAMP"
162189
163190 # Match simple numeric/boolean types
@@ -179,6 +206,7 @@ def _pl_tree_to_sql(tree: dict) -> str:
179206 # Time type
180207 if dtype == "Time" :
181208 nanoseconds = value ["Time" ]
209+ assert isinstance (nanoseconds , int ), f"A { dtype } should be an int but got { type (nanoseconds )} "
182210 seconds = nanoseconds // 1_000_000_000
183211 microseconds = (nanoseconds % 1_000_000_000 ) // 1_000
184212 dt_time = (datetime .datetime .min + datetime .timedelta (seconds = seconds , microseconds = microseconds )).time ()
@@ -187,36 +215,41 @@ def _pl_tree_to_sql(tree: dict) -> str:
187215 # Date type
188216 if dtype == "Date" :
189217 days_since_epoch = value ["Date" ]
218+ assert isinstance (days_since_epoch , (float , int )), (
219+ f"A { dtype } should be a number but got { type (days_since_epoch )} "
220+ )
190221 date = datetime .date (1970 , 1 , 1 ) + datetime .timedelta (days = days_since_epoch )
191222 return f"'{ date } '::DATE"
192223
193224 # Binary type
194225 if dtype == "Binary" :
195- binary_data = bytes (value ["Binary" ])
226+ bin_value = value ["Binary" ]
227+ assert isinstance (bin_value , bytes ), f"A { dtype } should be bytes but got { type (bin_value )} "
228+ binary_data = bytes (bin_value )
196229 escaped = "" .join (f"\\ x{ b :02x} " for b in binary_data )
197230 return f"'{ escaped } '::BLOB"
198231
199232 # String type
200233 if dtype == "String" or dtype == "StringOwned" :
201234 # Some new formats may store directly under StringOwned
202- string_val = value .get ("StringOwned" , value .get ("String" , None ))
235+ string_val : object | None = value .get ("StringOwned" , value .get ("String" , None ))
203236 return f"'{ string_val } '"
204237
205238 msg = f"Unsupported scalar type { dtype !s} , with value { value } "
206239 raise NotImplementedError (msg )
207240
208- msg = f"Node type: { node_type } is not implemented. { subtree } "
241+ msg = f"Node type: { node_type } is not implemented. { tree [ node_type ] } "
209242 raise NotImplementedError (msg )
210243
211244
212245def duckdb_source (relation : duckdb .DuckDBPyRelation , schema : pl .schema .Schema ) -> pl .LazyFrame :
213246 """A polars IO plugin for DuckDB."""
214247
215248 def source_generator (
216- with_columns : Optional [ list [str ]] ,
217- predicate : Optional [ pl .Expr ] ,
218- n_rows : Optional [ int ] ,
219- batch_size : Optional [ int ] ,
249+ with_columns : list [str ] | None ,
250+ predicate : pl .Expr | None ,
251+ n_rows : int | None ,
252+ batch_size : int | None ,
220253 ) -> Iterator [pl .DataFrame ]:
221254 duck_predicate = None
222255 relation_final = relation
@@ -239,8 +272,8 @@ def source_generator(
239272 for record_batch in iter (results .read_next_batch , None ):
240273 if predicate is not None and duck_predicate is None :
241274 # We have a predicate, but did not manage to push it down, we fallback here
242- yield pl .from_arrow (record_batch ).filter (predicate )
275+ yield pl .from_arrow (record_batch ).filter (predicate ) # type: ignore[arg-type,misc]
243276 else :
244- yield pl .from_arrow (record_batch )
277+ yield pl .from_arrow (record_batch ) # type: ignore[misc]
245278
246279 return register_io_source (source_generator , schema = schema )
0 commit comments