Commit 030a76a
[minor][spark] Minor bugfixes (#180)
This PR fixes 3 small bugs:

1. When you call asin(2.0) or acos(5.0) (values outside the valid input range), PySpark is supposed to return NaN. DuckDB was returning NULL instead, which breaks code that expects PySpark behavior. Fixed with a SQL workaround, since the internal C++ layer converts NaN to NULL by design.
2. The spark.read.json() method had all of its code in place, but a trailing raise NotImplementedError prevented it from ever returning. Removing that one line fixed it.
3. When a DuckDB union type is encountered (PySpark does not support union types), the code now raises a clear ContributionsAcceptedError explaining the limitation instead of a generic NotImplementedError.
2 parents bf7b2a0 + 743fdb2 commit 030a76a
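
For illustration, a rough sketch of the behavior fix 1 targets, using the experimental Spark API bundled with duckdb (the session setup and column name are illustrative, not part of this commit):

    from duckdb.experimental.spark.sql import SparkSession
    from duckdb.experimental.spark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(2.0,), (0.5,)], ["v"])

    # Out-of-range inputs such as 2.0 should now evaluate to NaN (not NULL),
    # matching PySpark semantics; in-range inputs such as 0.5 are unaffected.
    df.select(F.asin("v"), F.acos("v")).show()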

File tree

9 files changed: +324 additions, -19 deletions

duckdb/experimental/spark/sql/dataframe.py

Lines changed: 37 additions & 3 deletions
@@ -15,7 +15,6 @@
 from duckdb import ColumnExpression, Expression, StarExpression
 
 from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError
-from ..exception import ContributionsAcceptedError
 from .column import Column
 from .readwriter import DataFrameWriter
 from .type_utils import duckdb_to_spark_schema
@@ -569,6 +568,22 @@ def columns(self) -> list[str]:
         """
         return [f.name for f in self.schema.fields]
 
+    @property
+    def dtypes(self) -> list[tuple[str, str]]:
+        """Returns all column names and their data types as a list of tuples.
+
+        Returns:
+        -------
+        list of tuple
+            List of tuples, each tuple containing a column name and its data type as strings.
+
+        Examples:
+        --------
+        >>> df.dtypes
+        [('age', 'bigint'), ('name', 'string')]
+        """
+        return [(f.name, f.dataType.simpleString()) for f in self.schema.fields]
+
     def _ipython_key_completions_(self) -> list[str]:
         # Provides tab-completion for column names in PySpark DataFrame
         # when accessed in bracket notation, e.g. df['<TAB>]
@@ -982,8 +997,27 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
     def write(self) -> DataFrameWriter:  # noqa: D102
         return DataFrameWriter(self)
 
-    def printSchema(self) -> None:  # noqa: D102
-        raise ContributionsAcceptedError
+    def printSchema(self, level: Optional[int] = None) -> None:
+        """Prints out the schema in the tree format.
+
+        Parameters
+        ----------
+        level : int, optional
+            How many levels to print for nested schemas. Prints all levels by default.
+
+        Examples:
+        --------
+        >>> df.printSchema()
+        root
+         |-- age: bigint (nullable = true)
+         |-- name: string (nullable = true)
+        """
+        if level is not None and level < 0:
+            raise PySparkValueError(
+                error_class="NEGATIVE_VALUE",
+                message_parameters={"arg_name": "level", "arg_value": str(level)},
+            )
+        print(self.schema.treeString(level))
 
     def union(self, other: "DataFrame") -> "DataFrame":
         """Return a new :class:`DataFrame` containing union of rows in this and another

duckdb/experimental/spark/sql/functions.py

Lines changed: 26 additions & 5 deletions
@@ -30,6 +30,25 @@ def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column:
     return _invoke_function(name, *cols)
 
 
+def _nan_constant() -> Expression:
+    """Create a NaN constant expression.
+
+    Note: ConstantExpression(float("nan")) returns NULL instead of NaN because
+    TransformPythonValue() in the C++ layer has nan_as_null=true by default.
+    This is intentional for data import scenarios (CSV, Pandas, etc.) where NaN
+    represents missing data.
+
+    For mathematical functions that need to return NaN (not NULL) for out-of-range
+    inputs per PySpark/IEEE 754 semantics, we use SQLExpression as a workaround.
+
+    Returns:
+    -------
+    Expression
+        An expression that evaluates to NaN (not NULL)
+    """
+    return SQLExpression("'NaN'::DOUBLE")
+
+
 def col(column: str) -> Column:  # noqa: D103
     return Column(ColumnExpression(column))
 
@@ -617,11 +636,9 @@ def asin(col: "ColumnOrName") -> Column:
     +--------+
     """
     col = _to_column_expr(col)
-    # TODO: ConstantExpression(float("nan")) gives NULL and not NaN  # noqa: TD002, TD003
+    # asin domain is [-1, 1]; return NaN for out-of-range values per PySpark semantics
    return Column(
-        CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise(
-            FunctionExpression("asin", col)
-        )
+        CaseExpression((col < -1.0) | (col > 1.0), _nan_constant()).otherwise(FunctionExpression("asin", col))
     )
 
 
@@ -4177,7 +4194,11 @@ def acos(col: "ColumnOrName") -> Column:
     | NaN|
     +--------+
     """
-    return _invoke_function_over_columns("acos", col)
+    col = _to_column_expr(col)
+    # acos domain is [-1, 1]; return NaN for out-of-range values per PySpark semantics
+    return Column(
+        CaseExpression((col < -1.0) | (col > 1.0), _nan_constant()).otherwise(FunctionExpression("acos", col))
+    )
 
 
 def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
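
To see why the helper goes through SQLExpression rather than ConstantExpression, here is a rough sketch against the plain DuckDB relational API (the commented outputs are what the docstring above implies; exact behavior may vary by duckdb version):

    import duckdb
    from duckdb import ConstantExpression, SQLExpression

    rel = duckdb.sql("SELECT 42 AS i")

    # A Python float('nan') constant goes through TransformPythonValue() with
    # nan_as_null=true, so it arrives as NULL ...
    print(rel.select(ConstantExpression(float("nan"))).fetchall())  # [(None,)]

    # ... while the SQL cast used by _nan_constant() yields a genuine NaN.
    print(rel.select(SQLExpression("'NaN'::DOUBLE")).fetchall())  # [(nan,)]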

duckdb/experimental/spark/sql/readwriter.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def load(  # noqa: D102
         types, names = schema.extract_types_and_names()
         df = df._cast_types(types)
         df = df.toDF(names)
-        raise NotImplementedError
+        return df
 
     def csv(  # noqa: D102
         self,
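
With the stray raise removed, the JSON path through load() now returns the DataFrame it builds. A minimal sketch (the file path is a hypothetical placeholder):

    from duckdb.experimental.spark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # "people.json" stands in for any JSON file DuckDB's JSON reader can ingest.
    df = spark.read.json("people.json")
    df.show()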

duckdb/experimental/spark/sql/type_utils.py

Lines changed: 7 additions & 1 deletion
@@ -2,6 +2,7 @@
 
 from duckdb.sqltypes import DuckDBPyType
 
+from ..exception import ContributionsAcceptedError
 from .types import (
     ArrayType,
     BinaryType,
@@ -79,7 +80,12 @@ def convert_nested_type(dtype: DuckDBPyType) -> DataType:  # noqa: D103
     if id == "list" or id == "array":
         children = dtype.children
         return ArrayType(convert_type(children[0][1]))
-    # TODO: add support for 'union'  # noqa: TD002, TD003
+    if id == "union":
+        msg = (
+            "Union types are not supported in the PySpark interface. "
+            "DuckDB union types cannot be directly mapped to PySpark types."
+        )
+        raise ContributionsAcceptedError(msg)
     if id == "struct":
         children: list[tuple[str, DuckDBPyType]] = dtype.children
         fields = [StructField(x[0], convert_type(x[1])) for x in children]
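
A hedged sketch of how the new error can surface; union_value() is DuckDB's SQL constructor for union values, but whether the schema conversion runs eagerly or lazily here is an assumption, not something this diff shows:

    from duckdb.experimental.spark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Selecting a UNION-typed column and asking for its Spark schema should now
    # raise ContributionsAcceptedError with the message above, rather than a
    # bare NotImplementedError.
    df = spark.sql("SELECT union_value(num := 42) AS u")
    df.schema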

duckdb/experimental/spark/sql/types.py

Lines changed: 71 additions & 0 deletions
@@ -894,6 +894,77 @@ def fieldNames(self) -> list[str]:
         """
         return list(self.names)
 
+    def treeString(self, level: Optional[int] = None) -> str:
+        """Returns a string representation of the schema in tree format.
+
+        Parameters
+        ----------
+        level : int, optional
+            Maximum depth to print. If None, prints all levels.
+
+        Returns:
+        -------
+        str
+            Tree-formatted schema string
+
+        Examples:
+        --------
+        >>> schema = StructType([StructField("age", IntegerType(), True)])
+        >>> print(schema.treeString())
+        root
+         |-- age: integer (nullable = true)
+        """
+
+        def _tree_string(schema: "StructType", depth: int = 0, max_depth: Optional[int] = None) -> list[str]:
+            """Recursively build tree string lines."""
+            lines = []
+            if depth == 0:
+                lines.append("root")
+
+            if max_depth is not None and depth >= max_depth:
+                return lines
+
+            for field in schema.fields:
+                indent = " " * depth
+                prefix = " |-- "
+                nullable_str = "true" if field.nullable else "false"
+
+                # Handle nested StructType
+                if isinstance(field.dataType, StructType):
+                    lines.append(f"{indent}{prefix}{field.name}: struct (nullable = {nullable_str})")
+                    # Recursively handle nested struct - don't skip any lines, root only appears at depth 0
+                    nested_lines = _tree_string(field.dataType, depth + 1, max_depth)
+                    lines.extend(nested_lines)
+                # Handle ArrayType
+                elif isinstance(field.dataType, ArrayType):
+                    element_type = field.dataType.elementType
+                    if isinstance(element_type, StructType):
+                        lines.append(f"{indent}{prefix}{field.name}: array (nullable = {nullable_str})")
+                        lines.append(
+                            f"{indent} | |-- element: struct (containsNull = {field.dataType.containsNull})"
+                        )
+                        nested_lines = _tree_string(element_type, depth + 2, max_depth)
+                        lines.extend(nested_lines)
+                    else:
+                        type_str = element_type.simpleString()
+                        lines.append(f"{indent}{prefix}{field.name}: array<{type_str}> (nullable = {nullable_str})")
+                # Handle MapType
+                elif isinstance(field.dataType, MapType):
+                    key_type = field.dataType.keyType.simpleString()
+                    value_type = field.dataType.valueType.simpleString()
+                    lines.append(
+                        f"{indent}{prefix}{field.name}: map<{key_type},{value_type}> (nullable = {nullable_str})"
+                    )
+                # Handle simple types
+                else:
+                    type_str = field.dataType.simpleString()
+                    lines.append(f"{indent}{prefix}{field.name}: {type_str} (nullable = {nullable_str})")
+
+            return lines
+
+        lines = _tree_string(self, 0, level)
+        return "\n".join(lines)
+
     def needConversion(self) -> bool:  # noqa: D102
         # We need convert Row()/namedtuple into tuple()
         return True
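
A short sketch of treeString() on its own, using only the types module; the commented output is what the implementation above produces for flat and array fields:

    from duckdb.experimental.spark.sql.types import ArrayType, StringType, StructField, StructType

    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("hobbies", ArrayType(StringType()), True),
        ]
    )
    print(schema.treeString())
    # root
    #  |-- name: string (nullable = true)
    #  |-- hobbies: array<string> (nullable = true)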

tests/fast/spark/test_spark_dataframe.py

Lines changed: 170 additions & 0 deletions
@@ -427,3 +427,173 @@ def test_cache(self, spark):
         assert df is not cached
         assert cached.collect() == df.collect()
         assert cached.collect() == [Row(one=1, two=2, three=3, four=4)]
+
+    def test_dtypes(self, spark):
+        data = [("Alice", 25, 5000.0), ("Bob", 30, 6000.0)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        dtypes = df.dtypes
+
+        assert isinstance(dtypes, list)
+        assert len(dtypes) == 3
+        for col_name, col_type in dtypes:
+            assert isinstance(col_name, str)
+            assert isinstance(col_type, str)
+
+        col_names = [name for name, _ in dtypes]
+        assert col_names == ["name", "age", "salary"]
+        for _, col_type in dtypes:
+            assert len(col_type) > 0
+
+    def test_dtypes_complex_types(self, spark):
+        from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType
+
+        schema = StructType(
+            [
+                StructField("name", StringType(), True),
+                StructField("scores", ArrayType(IntegerType()), True),
+                StructField(
+                    "address",
+                    StructType([StructField("city", StringType(), True), StructField("zip", StringType(), True)]),
+                    True,
+                ),
+            ]
+        )
+        data = [
+            ("Alice", [90, 85, 88], {"city": "NYC", "zip": "10001"}),
+            ("Bob", [75, 80, 82], {"city": "LA", "zip": "90001"}),
+        ]
+        df = spark.createDataFrame(data, schema)
+        dtypes = df.dtypes
+
+        assert len(dtypes) == 3
+        col_names = [name for name, _ in dtypes]
+        assert col_names == ["name", "scores", "address"]
+
+    def test_printSchema(self, spark, capsys):
+        data = [("Alice", 25, 5000), ("Bob", 30, 6000)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        df.printSchema()
+        captured = capsys.readouterr()
+        output = captured.out
+
+        assert "root" in output
+        assert "name" in output
+        assert "age" in output
+        assert "salary" in output
+        assert "string" in output or "varchar" in output.lower()
+        assert "int" in output.lower() or "bigint" in output.lower()
+
+    def test_printSchema_nested(self, spark, capsys):
+        from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType
+
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField(
+                    "person",
+                    StructType([StructField("name", StringType(), True), StructField("age", IntegerType(), True)]),
+                    True,
+                ),
+                StructField("hobbies", ArrayType(StringType()), True),
+            ]
+        )
+        data = [
+            (1, {"name": "Alice", "age": 25}, ["reading", "coding"]),
+            (2, {"name": "Bob", "age": 30}, ["gaming", "music"]),
+        ]
+        df = spark.createDataFrame(data, schema)
+        df.printSchema()
+        captured = capsys.readouterr()
+        output = captured.out
+
+        assert "root" in output
+        assert "person" in output
+        assert "hobbies" in output
+
+    def test_printSchema_negative_level(self, spark):
+        data = [("Alice", 25)]
+        df = spark.createDataFrame(data, ["name", "age"])
+
+        with pytest.raises(PySparkValueError):
+            df.printSchema(level=-1)
+
+    def test_treeString_basic(self, spark):
+        data = [("Alice", 25, 5000)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        tree = df.schema.treeString()
+
+        assert tree.startswith("root\n")
+        assert " |-- name:" in tree
+        assert " |-- age:" in tree
+        assert " |-- salary:" in tree
+        assert "(nullable = true)" in tree
+        assert tree.count(" |-- ") == 3
+
+    def test_treeString_nested_struct(self, spark):
+        from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType
+
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField(
+                    "person",
+                    StructType([StructField("name", StringType(), True), StructField("age", IntegerType(), True)]),
+                    True,
+                ),
+            ]
+        )
+        data = [(1, {"name": "Alice", "age": 25})]
+        df = spark.createDataFrame(data, schema)
+        tree = df.schema.treeString()
+
+        assert "root\n" in tree
+        assert " |-- id:" in tree
+        assert " |-- person: struct (nullable = true)" in tree
+        assert "name:" in tree
+        assert "age:" in tree
+
+    def test_treeString_with_level(self, spark):
+        from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType
+
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField(
+                    "person",
+                    StructType(
+                        [
+                            StructField("name", StringType(), True),
+                            StructField("details", StructType([StructField("address", StringType(), True)]), True),
+                        ]
+                    ),
+                    True,
+                ),
+            ]
+        )
+
+        data = [(1, {"name": "Alice", "details": {"address": "123 Main St"}})]
+        df = spark.createDataFrame(data, schema)
+
+        # Level 1 should only show top-level fields
+        tree_level_1 = df.schema.treeString(level=1)
+        assert " |-- id:" in tree_level_1
+        assert " |-- person: struct" in tree_level_1
+        # Should not show nested field names at level 1
+        lines = tree_level_1.split("\n")
+        assert len([line for line in lines if line.strip()]) <= 3
+
+    def test_treeString_array_type(self, spark):
+        from spark_namespace.sql.types import ArrayType, StringType, StructField, StructType
+
+        schema = StructType(
+            [StructField("name", StringType(), True), StructField("hobbies", ArrayType(StringType()), True)]
+        )
+
+        data = [("Alice", ["reading", "coding"])]
+        df = spark.createDataFrame(data, schema)
+        tree = df.schema.treeString()
+
+        assert "root\n" in tree
+        assert " |-- name:" in tree
+        assert " |-- hobbies: array<" in tree
+        assert "(nullable = true)" in tree
