Skip to content

Commit fa17888

Browse files
committed
Add option to ignore metadata when comparing DataFrames
1 parent 70f0c6a commit fa17888

File tree

5 files changed

+104
-14
lines changed

5 files changed

+104
-14
lines changed

chispa/dataframe_comparer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class DataFramesNotEqualError(Exception):
1010

1111

1212
def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False,
13-
ignore_column_order=False, ignore_row_order=False, underline_cells=False):
13+
ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False):
1414
if transforms is None:
1515
transforms = []
1616
if ignore_column_order:
@@ -19,7 +19,7 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n
1919
transforms.append(lambda df: df.sort(df.columns))
2020
df1 = reduce(lambda acc, fn: fn(acc), transforms, df1)
2121
df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)
22-
assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
22+
assert_schema_equality(df1.schema, df2.schema, ignore_nullable, ignore_metadata)
2323
if allow_nan_equality:
2424
assert_generic_rows_equality(
2525
df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], underline_cells=underline_cells)

chispa/schema_comparer.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,37 @@ class SchemasNotEqualError(Exception):
88
pass
99

1010

11-
def assert_schema_equality(s1, s2, ignore_nullable=False, ignore_metadata=False):
    """Assert that two schemas are equal, optionally relaxing the comparison.

    Args:
        s1: First schema to compare.
        s2: Second schema to compare.
        ignore_nullable: Forwarded to ``assert_schema_equality_full``.
        ignore_metadata: Forwarded to ``assert_schema_equality_full``.

    When neither flag is set the strict ``assert_basic_schema_equality``
    check is used; otherwise the flag-aware comparison runs.
    """
    relaxed = ignore_nullable or ignore_metadata
    if relaxed:
        assert_schema_equality_full(s1, s2, ignore_nullable, ignore_metadata)
        return
    assert_basic_schema_equality(s1, s2)
16+
1617

18+
def assert_schema_equality_full(s1, s2, ignore_nullable=False, ignore_metadata=False):
    """Raise ``SchemasNotEqualError`` unless the two schemas match field by field.

    Args:
        s1: First schema; must support ``len()`` and iteration over its fields.
        s2: Second schema, same requirements as ``s1``.
        ignore_nullable: Passed to ``are_structfields_equal`` for every pair.
        ignore_metadata: Passed to ``are_structfields_equal`` for every pair.

    Raises:
        SchemasNotEqualError: If the schemas differ in length or any pair of
            fields is not equal under the given flags. The message is a
            two-column table; rows whose fields compare equal are wrapped in
            ``blue(...)``, the remaining rows are left unstyled.
    """
    def fields_match(sf1, sf2):
        # Use the caller's flags for both the verdict and the row colouring,
        # so the table highlights exactly the pairs that caused the failure.
        return are_structfields_equal(sf1, sf2, ignore_nullable, ignore_metadata)

    same_length = len(s1) == len(s2)
    if same_length and all(fields_match(sf1, sf2) for sf1, sf2 in zip(s1, s2)):
        return

    t = PrettyTable(["schema1", "schema2"])
    # zip_longest pads the shorter schema with None so surplus fields on
    # either side still get a row in the report.
    for sf1, sf2 in six.moves.zip_longest(s1, s2):
        if fields_match(sf1, sf2):
            t.add_row([blue(sf1), blue(sf2)])
        else:
            t.add_row([sf1, sf2])
    raise SchemasNotEqualError("\n" + t.get_string())
37+
38+
39+
# deprecate this
40+
# perhaps it is a little faster, but do we really need this?
41+
# I think schema equality operations are really fast to begin with
1842
def assert_basic_schema_equality(s1, s2):
1943
if s1 != s2:
2044
t = PrettyTable(["schema1", "schema2"])
@@ -27,8 +51,10 @@ def assert_basic_schema_equality(s1, s2):
2751
raise SchemasNotEqualError("\n" + t.get_string())
2852

2953

54+
55+
# TODO: deprecate this; callers should pass ignore_nullable=True to assert_schema_equality instead.
3056
def assert_schema_equality_ignore_nullable(s1, s2):
31-
if are_schemas_equal_ignore_nullable(s1, s2) == False:
57+
if not are_schemas_equal_ignore_nullable(s1, s2):
3258
t = PrettyTable(["schema1", "schema2"])
3359
zipped = list(six.moves.zip_longest(s1, s2))
3460
for sf1, sf2 in zipped:
@@ -39,6 +65,7 @@ def assert_schema_equality_ignore_nullable(s1, s2):
3965
raise SchemasNotEqualError("\n" + t.get_string())
4066

4167

68+
# TODO: deprecate this; ignore_nullable should be a flag rather than a separate function.
4269
def are_schemas_equal_ignore_nullable(s1, s2):
4370
if len(s1) != len(s2):
4471
return False
@@ -49,21 +76,25 @@ def are_schemas_equal_ignore_nullable(s1, s2):
4976
return True
5077

5178

52-
def are_structfields_equal(sf1, sf2, ignore_nullability=False, ignore_metadata=False):
    """Compare two struct fields, optionally ignoring nullability and/or metadata.

    Args:
        sf1: First field, or ``None`` (e.g. padding from ``zip_longest``).
        sf2: Second field, or ``None``.
        ignore_nullability: When true, nullable flags are not compared and
            data types are compared with ``are_datatypes_equal_ignore_nullable``.
        ignore_metadata: When true, the fields' ``metadata`` is not compared.

    Returns:
        ``True`` if the fields are equal under the requested relaxations,
        ``False`` otherwise. Two ``None`` values are equal; ``None`` never
        equals a real field.
    """
    # NOTE: "ignore_nullability" should be "ignore_nullable" for consistent
    # terminology; the name is kept so keyword callers keep working.
    if not ignore_nullability and not ignore_metadata:
        return sf1 == sf2

    if sf1 is None or sf2 is None:
        return sf1 is None and sf2 is None

    if sf1.name != sf2.name:
        return False
    if not ignore_metadata and sf1.metadata != sf2.metadata:
        return False

    if ignore_nullability:
        return are_datatypes_equal_ignore_nullable(sf1.dataType, sf2.dataType)

    # Only metadata is being ignored: nullability must still match, so compare
    # the nullable flag and the data type strictly.
    # NOTE(review): strict dataType equality presumably still compares metadata
    # of nested fields - confirm whether nested metadata should be ignored too.
    return sf1.nullable == sf2.nullable and sf1.dataType == sf2.dataType
6595

6696

97+
# deprecate this
6798
def are_datatypes_equal_ignore_nullable(dt1, dt2):
6899
"""Checks if datatypes are equal, descending into structs and arrays to
69100
ignore nullability.

tests/test_column_comparer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from .spark import *
3+
from .spark import spark
44
from chispa import *
55

66

tests/test_dataframe_comparer.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from chispa.dataframe_comparer import are_dfs_equal
66
from chispa.schema_comparer import SchemasNotEqualError
77
import math
8+
from pyspark.sql.types import StringType, IntegerType, StructType, StructField
89

910

1011
def describe_assert_df_equality():
@@ -17,7 +18,6 @@ def it_throws_with_schema_mismatches():
1718
assert_df_equality(df1, df2)
1819

1920

20-
2121
def it_can_work_with_different_row_orders():
2222
data1 = [(1, "jose"), (2, "li")]
2323
df1 = spark.createDataFrame(data1, ["num", "name"])
@@ -114,6 +114,46 @@ def it_does_not_consider_nan_values_equal_by_default():
114114
assert_df_equality(df1, df2, allow_nan_equality=False)
115115

116116

117+
def it_can_ignore_metadata():
    """DataFrames differing only in field metadata pass with ignore_metadata=True."""
    people = [("jose", 1), ("li", 2), ("luisa", 3)]

    def schema_with(name_metadata):
        # Build a fresh schema each time; only the "name" field's metadata varies.
        return StructType(
            [
                StructField("name", StringType(), True, name_metadata),
                StructField("age", IntegerType(), True),
            ]
        )

    left = spark.createDataFrame(people, schema_with({"hi": "no"}))
    right = spark.createDataFrame(people, schema_with({"hi": "whatever"}))
    assert_df_equality(left, right, ignore_metadata=True)
134+
135+
136+
def it_catches_mismatched_metadata():
    """By default, differing field metadata makes assert_df_equality raise."""
    rows_data = [("jose", 1), ("li", 2), ("luisa", 3)]
    schema1 = StructType(
        [
            StructField("name", StringType(), True, {"hi": "no"}),
            StructField("age", IntegerType(), True),
        ]
    )
    schema2 = StructType(
        [
            StructField("name", StringType(), True, {"hi": "whatever"}),
            StructField("age", IntegerType(), True),
        ]
    )
    df1 = spark.createDataFrame(rows_data, schema1)
    df2 = spark.createDataFrame(rows_data, schema2)
    # The exception info is never inspected, so it is not bound to a name.
    with pytest.raises(SchemasNotEqualError):
        assert_df_equality(df1, df2)
154+
155+
156+
117157
def describe_are_dfs_equal():
118158
def it_returns_false_with_schema_mismatches():
119159
data1 = [(1, "jose"), (2, "li"), (3, "laura")]

tests/test_schema_comparer.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def it_returns_false_when_columns_have_different_order():
124124
assert are_schemas_equal_ignore_nullable(s1, s2) == False
125125

126126

127-
def describe_are_structfield_types_equal_ignore_nullable():
127+
def describe_are_structfields_equal():
128128
def it_returns_true_when_only_nullable_flag_is_different_within_array_element():
129129
s1 = StructField("coords", ArrayType(DoubleType(), True), True)
130130
s2 = StructField("coords", ArrayType(DoubleType(), False), True)
@@ -159,3 +159,22 @@ def it_returns_true_when_different_nullability_within_struct():
159159
s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True)
160160
s2 = StructField("coords", StructType([StructField("hello", DoubleType(), False)]), True)
161161
assert are_structfields_equal(s1, s2, True) == True
162+
def it_returns_false_when_metadata_differs():
    """Fields with different metadata are unequal when metadata is not ignored."""
    left = StructField("coords", StringType(), True, {"hi": "whatever"})
    right = StructField("coords", StringType(), True, {"hi": "no"})
    outcome = are_structfields_equal(left, right, ignore_nullability=True, ignore_metadata=False)
    assert outcome is False
166+
167+
def it_allows_metadata_to_be_ignored():
    """Differing metadata is tolerated when ignore_metadata=True."""
    left = StructField("coords", StringType(), True, {"hi": "whatever"})
    right = StructField("coords", StringType(), True, {"hi": "no"})
    outcome = are_structfields_equal(left, right, ignore_nullability=False, ignore_metadata=True)
    assert outcome is True
171+
172+
def it_allows_nullability_and_metadata_to_be_ignored():
    """Differing nullability and metadata are both tolerated when both flags are set."""
    left = StructField("coords", StringType(), True, {"hi": "whatever"})
    right = StructField("coords", StringType(), False, {"hi": "no"})
    outcome = are_structfields_equal(left, right, ignore_nullability=True, ignore_metadata=True)
    assert outcome is True
176+
177+
def it_returns_true_when_metadata_is_the_same():
    """Identical metadata compares equal even when metadata is checked."""
    left = StructField("coords", StringType(), True, {"hi": "whatever"})
    right = StructField("coords", StringType(), True, {"hi": "whatever"})
    outcome = are_structfields_equal(left, right, ignore_nullability=True, ignore_metadata=False)
    assert outcome is True

0 commit comments

Comments
 (0)