Skip to content

Commit 320deed

Browse files
Dharin-shah and evertlammerts
authored and committed
minor fixes
1 parent c42e3f8 commit 320deed

File tree

4 files changed

+44
-11
lines changed

4 files changed

+44
-11
lines changed

duckdb/experimental/spark/sql/column.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
if TYPE_CHECKING:
88
from ._typing import DateTimeLiteral, DecimalLiteral, LiteralType
99

10-
from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression
10+
from duckdb import ConstantExpression, Expression, FunctionExpression
1111
from duckdb.sqltypes import DuckDBPyType
1212

1313
__all__ = ["Column"]
@@ -173,9 +173,11 @@ def __getitem__(self, k: Any) -> "Column": # noqa: ANN401
173173
# raise ValueError("Using a slice with a step value is not supported")
174174
# return self.substr(k.start, k.stop)
175175
else:
176-
# TODO: this is super hacky # noqa: TD002, TD003
177-
expr_str = str(self.expr) + "." + str(k)
178-
return Column(ColumnExpression(expr_str))
176+
# Use struct_extract for proper struct field access
177+
from duckdb import ConstantExpression, FunctionExpression
178+
179+
field_name_expr = ConstantExpression(str(k))
180+
return Column(FunctionExpression("struct_extract", self.expr, field_name_expr))
179181

180182
def __getattr__(self, item: Any) -> "Column": # noqa: ANN401
181183
"""An expression that gets an item at position ``ordinal`` out of a list,

duckdb/experimental/spark/sql/functions.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,29 @@ def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column:
3030
return _invoke_function(name, *cols)
3131

3232

def _nan_constant() -> Expression:
    """Build an expression that evaluates to a genuine NaN (not NULL).

    ``ConstantExpression(float("nan"))`` produces NULL instead of NaN:
    ``TransformPythonValue()`` in the C++ layer defaults to
    ``nan_as_null=true``, which is the intended behavior for data-import
    paths (CSV, Pandas, etc.) where NaN marks missing data.

    Mathematical functions, by contrast, must yield NaN — never NULL — for
    out-of-range inputs under PySpark/IEEE 754 semantics, so this helper
    routes through ``SQLExpression`` as a workaround.

    Returns:
    -------
    Expression
        An expression that evaluates to NaN (not NULL).

    See Also:
    --------
    NAN_ROOT_CAUSE_ANALYSIS.md for the full explanation.
    """
    nan_expr = SQLExpression("'NaN'::DOUBLE")
    return nan_expr
3356
def col(column: str) -> Column: # noqa: D103
3457
return Column(ColumnExpression(column))
3558

@@ -617,11 +640,9 @@ def asin(col: "ColumnOrName") -> Column:
617640
+--------+
618641
"""
619642
col = _to_column_expr(col)
620-
# TODO: ConstantExpression(float("nan")) gives NULL and not NaN # noqa: TD002, TD003
643+
# asin domain is [-1, 1]; return NaN for out-of-range values per PySpark semantics
621644
return Column(
622-
CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise(
623-
FunctionExpression("asin", col)
624-
)
645+
CaseExpression((col < -1.0) | (col > 1.0), _nan_constant()).otherwise(FunctionExpression("asin", col))
625646
)
626647

627648

@@ -4177,7 +4198,11 @@ def acos(col: "ColumnOrName") -> Column:
41774198
| NaN|
41784199
+--------+
41794200
"""
4180-
return _invoke_function_over_columns("acos", col)
4201+
col = _to_column_expr(col)
4202+
# acos domain is [-1, 1]; return NaN for out-of-range values per PySpark semantics
4203+
return Column(
4204+
CaseExpression((col < -1.0) | (col > 1.0), _nan_constant()).otherwise(FunctionExpression("acos", col))
4205+
)
41814206

41824207

41834208
def call_function(funcName: str, *cols: "ColumnOrName") -> Column:

duckdb/experimental/spark/sql/readwriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def load( # noqa: D102
125125
types, names = schema.extract_types_and_names()
126126
df = df._cast_types(types)
127127
df = df.toDF(names)
128-
raise NotImplementedError
128+
return df
129129

130130
def csv( # noqa: D102
131131
self,

duckdb/experimental/spark/sql/type_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from duckdb.sqltypes import DuckDBPyType
44

5+
from ..exception import ContributionsAcceptedError
56
from .types import (
67
ArrayType,
78
BinaryType,
@@ -79,7 +80,12 @@ def convert_nested_type(dtype: DuckDBPyType) -> DataType: # noqa: D103
7980
if id == "list" or id == "array":
8081
children = dtype.children
8182
return ArrayType(convert_type(children[0][1]))
82-
# TODO: add support for 'union' # noqa: TD002, TD003
83+
if id == "union":
84+
msg = (
85+
"Union types are not supported in the PySpark interface. "
86+
"DuckDB union types cannot be directly mapped to PySpark types."
87+
)
88+
raise ContributionsAcceptedError(msg)
8389
if id == "struct":
8490
children: list[tuple[str, DuckDBPyType]] = dtype.children
8591
fields = [StructField(x[0], convert_type(x[1])) for x in children]

0 commit comments

Comments
 (0)