
Commit 84b3708

Dharin-shah authored and evertlammerts committed
add treeString support for printSchema
1 parent 748031d commit 84b3708

File tree

2 files changed, +139 -0 lines changed

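For orientation, here is a minimal usage sketch of the new API (not part of the commit). The import path is inferred from the file location duckdb/experimental/spark/sql/types.py, and the exact type names in the printed tree come from simpleString(), so both should be treated as assumptions rather than guaranteed behaviour.

from duckdb.experimental.spark.sql.types import (  # path assumed from the file location
    ArrayType,
    StringType,
    StructField,
    StructType,
)

# A schema with a plain column, a nested struct, and an array column.
schema = StructType([
    StructField("name", StringType(), True),
    StructField("address", StructType([
        StructField("city", StringType(), True),
    ]), True),
    StructField("hobbies", ArrayType(StringType()), True),
])

print(schema.treeString())         # full tree
print(schema.treeString(level=1))  # only top-level fields, per the max_depth check below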

duckdb/experimental/spark/sql/types.py

Lines changed: 66 additions & 0 deletions
@@ -894,6 +894,72 @@ def fieldNames(self) -> list[str]:
         """
         return list(self.names)
 
+    def treeString(self, level: Optional[int] = None) -> str:
+        """Returns a string representation of the schema in tree format.
+
+        Parameters
+        ----------
+        level : int, optional
+            Maximum depth to print. If None, prints all levels.
+
+        Returns:
+        -------
+        str
+            Tree-formatted schema string
+
+        Examples:
+        --------
+        >>> schema = StructType([StructField("age", IntegerType(), True)])
+        >>> print(schema.treeString())
+        root
+         |-- age: integer (nullable = true)
+        """
+        def _tree_string(schema: "StructType", depth: int = 0, max_depth: Optional[int] = None) -> list[str]:
+            """Recursively build tree string lines."""
+            lines = []
+            if depth == 0:
+                lines.append("root")
+
+            if max_depth is not None and depth >= max_depth:
+                return lines
+
+            for field in schema.fields:
+                indent = " " * depth
+                prefix = " |-- "
+                nullable_str = "true" if field.nullable else "false"
+
+                # Handle nested StructType
+                if isinstance(field.dataType, StructType):
+                    lines.append(f"{indent}{prefix}{field.name}: struct (nullable = {nullable_str})")
+                    # Recursively handle nested struct - don't skip any lines, root only appears at depth 0
+                    nested_lines = _tree_string(field.dataType, depth + 1, max_depth)
+                    lines.extend(nested_lines)
+                # Handle ArrayType
+                elif isinstance(field.dataType, ArrayType):
+                    element_type = field.dataType.elementType
+                    if isinstance(element_type, StructType):
+                        lines.append(f"{indent}{prefix}{field.name}: array (nullable = {nullable_str})")
+                        lines.append(f"{indent} | |-- element: struct (containsNull = {field.dataType.containsNull})")
+                        nested_lines = _tree_string(element_type, depth + 2, max_depth)
+                        lines.extend(nested_lines)
+                    else:
+                        type_str = element_type.simpleString()
+                        lines.append(f"{indent}{prefix}{field.name}: array<{type_str}> (nullable = {nullable_str})")
+                # Handle MapType
+                elif isinstance(field.dataType, MapType):
+                    key_type = field.dataType.keyType.simpleString()
+                    value_type = field.dataType.valueType.simpleString()
+                    lines.append(f"{indent}{prefix}{field.name}: map<{key_type},{value_type}> (nullable = {nullable_str})")
+                # Handle simple types
+                else:
+                    type_str = field.dataType.simpleString()
+                    lines.append(f"{indent}{prefix}{field.name}: {type_str} (nullable = {nullable_str})")
+
+            return lines
+
+        lines = _tree_string(self, 0, level)
+        return "\n".join(lines)
+
     def needConversion(self) -> bool:  # noqa: D102
         # We need convert Row()/namedtuple into tuple()
         return True
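Hand-tracing _tree_string above makes the output shape concrete: the indent grows by one space per nesting level, " |-- " is the fixed per-field prefix, and reaching max_depth simply stops the recursion. A sketch of what that yields for a one-level nested struct, with the same import-path caveat as above and assuming StringType().simpleString() returns "string" as it does in PySpark:

from duckdb.experimental.spark.sql.types import StringType, StructField, StructType

nested = StructType([
    StructField("person", StructType([
        StructField("name", StringType(), True),
    ]), True),
])

print(nested.treeString())
# Hand-traced expectation:
# root
#  |-- person: struct (nullable = true)
#   |-- name: string (nullable = true)

print(nested.treeString(level=1))
# depth 1 reaches max_depth, so the nested field is omitted:
# root
#  |-- person: struct (nullable = true)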

tests/fast/spark/test_spark_dataframe.py

Lines changed: 73 additions & 0 deletions
@@ -508,3 +508,76 @@ def test_printSchema_negative_level(self, spark):
 
         with pytest.raises(PySparkValueError):
             df.printSchema(level=-1)
+
+    def test_treeString_basic(self, spark):
+        data = [("Alice", 25, 5000)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        tree = df.schema.treeString()
+
+        assert tree.startswith("root\n")
+        assert " |-- name:" in tree
+        assert " |-- age:" in tree
+        assert " |-- salary:" in tree
+        assert "(nullable = true)" in tree
+        assert tree.count(" |-- ") == 3
+
+    def test_treeString_nested_struct(self, spark):
+        from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType
+
+        schema = StructType([
+            StructField("id", IntegerType(), True),
+            StructField("person", StructType([
+                StructField("name", StringType(), True),
+                StructField("age", IntegerType(), True)
+            ]), True)
+        ])
+        data = [(1, {"name": "Alice", "age": 25})]
+        df = spark.createDataFrame(data, schema)
+        tree = df.schema.treeString()
+
+        assert "root\n" in tree
+        assert " |-- id:" in tree
+        assert " |-- person: struct (nullable = true)" in tree
+        assert "name:" in tree
+        assert "age:" in tree
+
+    def test_treeString_with_level(self, spark):
+        from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType
+
+        schema = StructType([
+            StructField("id", IntegerType(), True),
+            StructField("person", StructType([
+                StructField("name", StringType(), True),
+                StructField("details", StructType([
+                    StructField("address", StringType(), True)
+                ]), True)
+            ]), True)
+        ])
+
+        data = [(1, {"name": "Alice", "details": {"address": "123 Main St"}})]
+        df = spark.createDataFrame(data, schema)
+
+        # Level 1 should only show top-level fields
+        tree_level_1 = df.schema.treeString(level=1)
+        assert " |-- id:" in tree_level_1
+        assert " |-- person: struct" in tree_level_1
+        # Should not show nested field names at level 1
+        lines = tree_level_1.split('\n')
+        assert len([l for l in lines if l.strip()]) <= 3
+
+    def test_treeString_array_type(self, spark):
+        from spark_namespace.sql.types import ArrayType, StringType, StructField, StructType
+
+        schema = StructType([
+            StructField("name", StringType(), True),
+            StructField("hobbies", ArrayType(StringType()), True)
+        ])
+
+        data = [("Alice", ["reading", "coding"])]
+        df = spark.createDataFrame(data, schema)
+        tree = df.schema.treeString()
+
+        assert "root\n" in tree
+        assert " |-- name:" in tree
+        assert " |-- hobbies: array<" in tree
+        assert "(nullable = true)" in tree
