Skip to content

Commit 957febe

Browse files
committed
wip
1 parent 591e58f commit 957febe

File tree

3 files changed

+27
-49
lines changed

3 files changed

+27
-49
lines changed

chispa/default_formats.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,5 +4,5 @@
44
class DefaultFormats:
55
mismatched_rows = ["red"]
66
matched_rows = ["blue"]
7-
mismatched_cells = ["white", "underline"]
7+
mismatched_cells = ["red", "underline"]
88
matched_cells = ["blue"]

chispa/rows_comparer.py

Lines changed: 23 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -40,61 +40,38 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=Defa
4040
raise chispa.DataFramesNotEqualError("\n" + t.get_string())
4141

4242

43-
def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False):
43+
def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False, formats=DefaultFormats()):
4444
df1_rows = rows1
4545
df2_rows = rows2
4646
zipped = list(six.moves.zip_longest(df1_rows, df2_rows))
4747
t = PrettyTable(["df1", "df2"])
48-
allRowsEqual = True
49-
if underline_cells:
50-
row_column_names = rows1[0].__fields__
51-
num_columns = len(row_column_names)
48+
all_rows_equal = True
5249
for r1, r2 in zipped:
5350
# rows are not equal when one is None and the other isn't
5451
if (r1 is not None and r2 is None) or (r2 is not None and r1 is None):
55-
allRowsEqual = False
56-
t.add_row([r1, r2])
52+
all_rows_equal = False
53+
t.add_row([format_string(r1, formats.mismatched_rows), format_string(r2, formats.mismatched_rows)])
5754
# rows are equal
5855
elif row_equality_fun(r1, r2, *row_equality_fun_args):
59-
first = bcolors.LightBlue + str(r1) + bcolors.LightRed
60-
second = bcolors.LightBlue + str(r2) + bcolors.LightRed
61-
t.add_row([first, second])
56+
r1_string = ", ".join(map(lambda f: f"{f}={r1[f]}", r1.__fields__))
57+
r2_string = ", ".join(map(lambda f: f"{f}={r2[f]}", r2.__fields__))
58+
t.add_row([format_string(r1_string, formats.matched_rows), format_string(r2_string, formats.matched_rows)])
6259
# otherwise, rows aren't equal
6360
else:
64-
allRowsEqual = False
65-
# Underline cells if requested
66-
if underline_cells:
67-
t.add_row(__underline_cells_in_row(
68-
r1=r1, r2=r2, row_column_names=row_column_names, num_columns=num_columns))
69-
else:
70-
t.add_row([r1, r2])
71-
if allRowsEqual == False:
72-
raise chispa.DataFramesNotEqualError("\n" + t.get_string())
73-
74-
75-
def __underline_cells_in_row(r1=Row, r2=Row, row_column_names=List[str], num_columns=int) -> List[str]:
76-
"""
77-
Takes two Row types, a list of column names for the Rows and the length of columns
78-
Returns list of two strings, with underlined columns within rows that are different for PrettyTable
79-
"""
80-
r1_string = "Row("
81-
r2_string = "Row("
82-
for index, column in enumerate(row_column_names):
83-
if ((index+1) == num_columns):
84-
append_str = ""
85-
else:
86-
append_str = ", "
61+
r_zipped = list(six.moves.zip_longest(r1.__fields__, r2.__fields__))
62+
r1_string = []
63+
r2_string = []
64+
for r1_field, r2_field in r_zipped:
65+
if r1[r1_field] != r2[r2_field]:
66+
all_rows_equal = False
67+
r1_string.append(format_string(f"{r1_field}='{r1[r1_field]}'", formats.mismatched_cells))
68+
r2_string.append(format_string(f"{r2_field}='{r2[r2_field]}'", formats.mismatched_cells))
69+
else:
70+
r1_string.append(format_string(f"{r1_field}='{r1[r1_field]}'", formats.matched_cells))
71+
r2_string.append(format_string(f"{r2_field}='{r2[r2_field]}'", formats.matched_cells))
72+
r1_res = ", ".join(r1_string)
73+
r2_res = ", ".join(r2_string)
8774

88-
if r1[column] != r2[column]:
89-
r1_string += underline_text(
90-
f"{column}='{r1[column]}'") + f"{append_str}"
91-
r2_string += underline_text(
92-
f"{column}='{r2[column]}'") + f"{append_str}"
93-
else:
94-
r1_string += f"{column}='{r1[column]}'{append_str}"
95-
r2_string += f"{column}='{r2[column]}'{append_str}"
96-
97-
r1_string += ")"
98-
r2_string += ")"
99-
100-
return [bcolors.LightRed + r1_string, r2_string]
75+
t.add_row([r1_res, r2_res])
76+
if all_rows_equal == False:
77+
raise chispa.DataFramesNotEqualError("\n" + t.get_string())

tests/test_readme_examples.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -151,8 +151,8 @@ def it_prints_underline_message():
151151
("rick", 66),
152152
]
153153
df2 = spark.createDataFrame(data, ["firstname", "age"])
154-
# with pytest.raises(DataFramesNotEqualError) as e_info:
155-
# assert_df_equality(df1, df2, formats=my_formats)
154+
with pytest.raises(DataFramesNotEqualError) as e_info:
155+
assert_df_equality(df1, df2, underline_cells=True)
156156

157157
def it_shows_assert_basic_rows_equality(my_formats):
158158
data = [
@@ -231,6 +231,7 @@ def test_approx_df_equality_different():
231231
(None, None)
232232
]
233233
df2 = spark.createDataFrame(data2, ["num", "letter"])
234+
# assert_approx_df_equality(df1, df2, 0.1)
234235
with pytest.raises(DataFramesNotEqualError) as e_info:
235236
assert_approx_df_equality(df1, df2, 0.1)
236237

0 commit comments

Comments
 (0)