Commit b32e901

Dharin-shah authored and evertlammerts committed
[minor] add support for types and print schema
1 parent f2b5da9 commit b32e901

2 files changed: +128 -2 lines changed

duckdb/experimental/spark/sql/dataframe.py

Lines changed: 37 additions & 2 deletions
@@ -569,6 +569,22 @@ def columns(self) -> list[str]:
         """
         return [f.name for f in self.schema.fields]
 
+    @property
+    def dtypes(self) -> list[tuple[str, str]]:
+        """Returns all column names and their data types as a list of tuples.
+
+        Returns
+        -------
+        list of tuple
+            List of tuples, each tuple containing a column name and its data type as strings.
+
+        Examples
+        --------
+        >>> df.dtypes
+        [('age', 'bigint'), ('name', 'string')]
+        """
+        return [(f.name, f.dataType.simpleString()) for f in self.schema.fields]
+
     def _ipython_key_completions_(self) -> list[str]:
         # Provides tab-completion for column names in PySpark DataFrame
         # when accessed in bracket notation, e.g. df['<TAB>]

@@ -982,8 +998,27 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
     def write(self) -> DataFrameWriter:  # noqa: D102
         return DataFrameWriter(self)
 
-    def printSchema(self) -> None:  # noqa: D102
-        raise ContributionsAcceptedError
+    def printSchema(self, level: Optional[int] = None) -> None:
+        """Prints out the schema in the tree format.
+
+        Parameters
+        ----------
+        level : int, optional
+            How many levels to print for nested schemas. Prints all levels by default.
+
+        Examples
+        --------
+        >>> df.printSchema()
+        root
+         |-- age: bigint (nullable = true)
+         |-- name: string (nullable = true)
+        """
+        if level is not None and level < 0:
+            raise PySparkValueError(
+                error_class="NEGATIVE_VALUE",
+                message_parameters={"arg_name": "level", "arg_value": str(level)},
+            )
+        print(self.schema.treeString(level))
 
     def union(self, other: "DataFrame") -> "DataFrame":
         """Return a new :class:`DataFrame` containing union of rows in this and another

tests/fast/spark/test_spark_dataframe.py

Lines changed: 91 additions & 0 deletions
@@ -427,3 +427,94 @@ def test_cache(self, spark):
         assert df is not cached
         assert cached.collect() == df.collect()
         assert cached.collect() == [Row(one=1, two=2, three=3, four=4)]
+
+    def test_dtypes(self, spark):
+        data = [("Alice", 25, 5000.0), ("Bob", 30, 6000.0)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        dtypes = df.dtypes
+        assert isinstance(dtypes, list)
+        assert len(dtypes) == 3
+        for col_name, col_type in dtypes:
+            assert isinstance(col_name, str)
+            assert isinstance(col_type, str)
+        col_names = [name for name, _ in dtypes]
+        assert col_names == ["name", "age", "salary"]
+        for _, col_type in dtypes:
+            assert len(col_type) > 0  # Should have some type string
+
+    def test_dtypes_complex_types(self, spark):
+        from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType
+
+        schema = StructType([
+            StructField("name", StringType(), True),
+            StructField("scores", ArrayType(IntegerType()), True),
+            StructField("address", StructType([
+                StructField("city", StringType(), True),
+                StructField("zip", StringType(), True)
+            ]), True)
+        ])
+
+        data = [
+            ("Alice", [90, 85, 88], {"city": "NYC", "zip": "10001"}),
+            ("Bob", [75, 80, 82], {"city": "LA", "zip": "90001"})
+        ]
+
+        df = spark.createDataFrame(data, schema)
+        dtypes = df.dtypes
+
+        assert len(dtypes) == 3
+        col_names = [name for name, _ in dtypes]
+        assert col_names == ["name", "scores", "address"]
+
+    def test_printSchema(self, spark, capsys):
+        data = [("Alice", 25, 5000), ("Bob", 30, 6000)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        df.printSchema()
+        captured = capsys.readouterr()
+        output = captured.out
+        assert "root" in output
+        assert "name" in output
+        assert "age" in output
+        assert "salary" in output
+        assert "string" in output or "varchar" in output.lower()
+        assert "int" in output.lower() or "bigint" in output.lower()
+
+    def test_printSchema_nested(self, spark, capsys):
+        # Test printSchema with nested schema
+        from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType
+
+        schema = StructType([
+            StructField("id", IntegerType(), True),
+            StructField("person", StructType([
+                StructField("name", StringType(), True),
+                StructField("age", IntegerType(), True)
+            ]), True),
+            StructField("hobbies", ArrayType(StringType()), True)
+        ])
+
+        data = [
+            (1, {"name": "Alice", "age": 25}, ["reading", "coding"]),
+            (2, {"name": "Bob", "age": 30}, ["gaming", "music"])
+        ]
+
+        df = spark.createDataFrame(data, schema)
+
+        # Should not raise an error
+        df.printSchema()
+
+        captured = capsys.readouterr()
+        output = captured.out
+
+        # Verify nested structure is shown
+        assert "root" in output
+        assert "person" in output
+        assert "hobbies" in output
+
+    def test_printSchema_negative_level(self, spark):
+        # Test printSchema with invalid level parameter
+        data = [("Alice", 25)]
+        df = spark.createDataFrame(data, ["name", "age"])
+
+        # Should raise PySparkValueError for negative level
+        with pytest.raises(PySparkValueError):
+            df.printSchema(level=-1)
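For the nested schema built in test_printSchema_nested, a hedged sketch of how the level argument is expected to truncate the tree (mirroring PySpark's treeString() behavior; a negative level raises PySparkValueError, as the last test checks):

    # Hypothetical sketch; df is the nested DataFrame from test_printSchema_nested.
    df.printSchema(level=1)  # print only the top level of the tree
    # root
    #  |-- id: integer (nullable = true)
    #  |-- person: struct (nullable = true)
    #  |-- hobbies: array (nullable = true)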
