# Add allow_nan_equality option to assert_approx_df_equality #29
base: main
Changes from all commits: 945443e, b94d4d5, e74368e, 5fe032a, 958b19c, 34fd432, c3b907c, 147bb57, 3ff1f18, 47f900a, 461c978, 0533904
This PR adds a new changelog file:

```diff
@@ -0,0 +1,17 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## Unreleased
+
+### Changed
+- `DataFramesNotEqualError` changed to `RowsNotEqualError` to reflect it being raised when testing for row equality.
+- The assertion functions `assert_df_equality` and `assert_column_equality` now have an optional `precision` parameter to test for approximate equality.
+
+### Removed
+- Removed `are_dfs_equal` because it has been superseded by other parts of the API.
+- Removed `assert_approx_df_equality` as it has been replaced by adding the optional `precision` parameter to `assert_df_equality`.
+- Removed `assert_approx_column_equality` as it has been replaced by adding the optional `precision` parameter to `assert_column_equality`.
```
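For orientation (not part of the diff itself), a minimal sketch of how the consolidated API is called once `precision` is folded into `assert_df_equality`. The SparkSession setup and sample frames are illustrative assumptions, and the import presumes the existing top-level re-export stays in place.

```python
# Illustrative only: sample data and SparkSession are assumptions.
from pyspark.sql import SparkSession
from chispa import assert_df_equality

spark = SparkSession.builder.master("local[1]").appName("chispa-demo").getOrCreate()

df1 = spark.createDataFrame([(1.001,), (2.0,)], ["num"])
df2 = spark.createDataFrame([(1.0,), (2.0,)], ["num"])

# Previously this needed assert_approx_df_equality(df1, df2, 0.01); after this
# PR the precision keyword covers the approximate case.
assert_df_equality(df1, df2, precision=0.01)
```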
Changes to the column comparison helpers:

```diff
@@ -1,53 +1,74 @@
-from chispa.bcolors import *
+from typing import Optional, Any
+
+from pyspark.sql import DataFrame
+from pyspark.sql.types import DataType
+
+from chispa.bcolors import blue
 from chispa.prettytable import PrettyTable
+from chispa.number_helpers import check_equal
 
 
 class ColumnsNotEqualError(Exception):
     """The columns are not equal"""
     pass
 
 
-def assert_column_equality(df, col_name1, col_name2):
-    elements = df.select(col_name1, col_name2).collect()
-    colName1Elements = list(map(lambda x: x[0], elements))
-    colName2Elements = list(map(lambda x: x[1], elements))
-    if colName1Elements != colName2Elements:
-        zipped = list(zip(colName1Elements, colName2Elements))
-        t = PrettyTable([col_name1, col_name2])
-        for elements in zipped:
-            if elements[0] == elements[1]:
-                first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed
-                second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed
-                t.add_row([first, second])
-            else:
-                t.add_row([str(elements[0]), str(elements[1])])
-        raise ColumnsNotEqualError("\n" + t.get_string())
+def assert_column_equality(
+    df: DataFrame,
+    col_name1: str,
+    col_name2: str,
+    precision: Optional[float] = None,
+    allow_nan_equality: bool = False,
+) -> None:
+    """Assert that two columns in a PySpark DataFrame are equal.
+
+    Parameters
+    ----------
+    precision : float, optional
+        Absolute tolerance when checking for equality.
+    allow_nan_equality : bool, default False
+        When True, treats two NaN values as equal.
 
-def assert_approx_column_equality(df, col_name1, col_name2, precision):
-    elements = df.select(col_name1, col_name2).collect()
-    colName1Elements = list(map(lambda x: x[0], elements))
-    colName2Elements = list(map(lambda x: x[1], elements))
+    """
     all_rows_equal = True
-    zipped = list(zip(colName1Elements, colName2Elements))
     t = PrettyTable([col_name1, col_name2])
+
+    # Zip both columns together for iterating through elements.
+    columns = df.select(col_name1, col_name2).collect()
+    zipped = zip(*[list(map(lambda x: x[i], columns)) for i in [0, 1]])
+
     for elements in zipped:
-        first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed
-        second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed
-        # when one is None and the other isn't, they're not equal
-        if (elements[0] == None and elements[1] != None) or (elements[0] != None and elements[1] == None):
-            all_rows_equal = False
-            t.add_row([str(elements[0]), str(elements[1])])
-        # when both are None, they're equal
-        elif elements[0] == None and elements[1] == None:
-            t.add_row([first, second])
-        # when the diff is less than the threshhold, they're approximately equal
-        elif abs(elements[0] - elements[1]) < precision:
-            t.add_row([first, second])
-        # otherwise, they're not equal
+        if are_elements_equal(*elements, precision, allow_nan_equality):
+            t.add_row([blue(e) for e in elements])
         else:
             all_rows_equal = False
-            t.add_row([str(elements[0]), str(elements[1])])
+            t.add_row([str(e) for e in elements])
 
     if all_rows_equal == False:
         raise ColumnsNotEqualError("\n" + t.get_string())
+
+
+def are_elements_equal(
+    e1: DataType,
+    e2: DataType,
+    precision: Optional[float] = None,
+    allow_nan_equality: bool = False,
+) -> bool:
+    """
+    Return True if both elements are equal.
+
+    Parameters
+    ----------
+    precision : float, optional
+        Absolute tolerance when checking for equality.
+    allow_nan_equality: bool, default False
+        When True, treats two NaN values as equal.
+
+    """
+    # If both elements are None they are considered equal.
+    if e1 is None and e2 is None:
+        return True
+    if (e1 is None and e2 is not None) or (e2 is None and e1 is not None):
+        return False
+
+    return check_equal(e1, e2, precision, allow_nan_equality)
```
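A small usage sketch (not from the diff) exercising the new `precision` and `allow_nan_equality` options on `assert_column_equality`. The sample DataFrame is an illustrative assumption, and the import presumes the existing top-level re-export.

```python
# Illustrative only: compares two columns of one DataFrame under the new options.
from pyspark.sql import SparkSession
from chispa import assert_column_equality

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [(1.1, 1.1), (2.2, 2.200001), (float("nan"), float("nan"))],
    ["col1", "col2"],
)

# Passes: the tiny numeric difference is inside the tolerance, and the NaN
# pair counts as equal because allow_nan_equality=True.
assert_column_equality(df, "col1", "col2", precision=0.001, allow_nan_equality=True)
```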
Changes to the DataFrame comparison helpers:

```diff
@@ -1,78 +1,51 @@
-from chispa.prettytable import PrettyTable
-from chispa.bcolors import *
-from chispa.schema_comparer import assert_schema_equality
-from chispa.row_comparer import *
-import chispa.six as six
 from functools import reduce
+from typing import Callable, Optional
+
+from pyspark.sql import DataFrame
 
-class DataFramesNotEqualError(Exception):
-    """The DataFrames are not equal"""
-    pass
-
-
-def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, ignore_column_order=False, ignore_row_order=False):
+from chispa.schema_comparer import assert_schema_equality
+from chispa.row_comparer import assert_rows_equality
+
+
+def assert_df_equality(
+    df1: DataFrame,
+    df2: DataFrame,
+    precision: Optional[float] = None,
+    ignore_nullable: bool = False,
+    allow_nan_equality: bool = False,
+    ignore_column_order: bool = False,
+    ignore_row_order: bool = False,
+    transforms: Callable[[DataFrame], DataFrame] = None,
+) -> None:
+    """Assert that two PySpark DataFrames are equal.
+
+    Parameters
+    ----------
+    precision : float, optional
+        Absolute tolerance when checking for equality.
+    ignore_nullable : bool, default False
+        Ignore nullable option when comparing schemas.
+    allow_nan_equality : bool, default False
+        When True, treats two NaN values as equal.
+    ignore_column_order : bool, default False
+        When True, sorts columns before comparing.
+    ignore_row_order : bool, default False
+        When True, sorts all rows before comparing.
+    transforms : callable
+        Additional transforms to make to DataFrame before comparison.
+
+    """
+    # Apply row and column order transforms + custom transforms.
     if transforms is None:
         transforms = []
     if ignore_column_order:
         transforms.append(lambda df: df.select(sorted(df.columns)))
     if ignore_row_order:
         transforms.append(lambda df: df.sort(df.columns))
+
     df1 = reduce(lambda acc, fn: fn(acc), transforms, df1)
     df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)
-    assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
-    if allow_nan_equality:
-        assert_generic_rows_equality(df1, df2, are_rows_equal_enhanced, [True])
-    else:
-        assert_basic_rows_equality(df1, df2)
-
-
-def are_dfs_equal(df1, df2):
-    if df1.schema != df2.schema:
-        return False
-    if df1.collect() != df2.collect():
-        return False
-    return True
-
-
-def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False):
+    # Check schema and row equality.
     assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
-    assert_generic_rows_equality(df1, df2, are_rows_approx_equal, [precision])
-
-
-def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_args):
-    df1_rows = df1.collect()
-    df2_rows = df2.collect()
-    zipped = list(six.moves.zip_longest(df1_rows, df2_rows))
-    t = PrettyTable(["df1", "df2"])
-    allRowsEqual = True
-    for r1, r2 in zipped:
-        # rows are not equal when one is None and the other isn't
-        if (r1 is not None and r2 is None) or (r2 is not None and r1 is None):
-            allRowsEqual = False
-            t.add_row([r1, r2])
-        # rows are equal
-        elif row_equality_fun(r1, r2, *row_equality_fun_args):
-            first = bcolors.LightBlue + str(r1) + bcolors.LightRed
-            second = bcolors.LightBlue + str(r2) + bcolors.LightRed
-            t.add_row([first, second])
-        # otherwise, rows aren't equal
-        else:
-            allRowsEqual = False
-            t.add_row([r1, r2])
-    if allRowsEqual == False:
-        raise DataFramesNotEqualError("\n" + t.get_string())
-
-
-def assert_basic_rows_equality(df1, df2):
-    rows1 = df1.collect()
-    rows2 = df2.collect()
-    if rows1 != rows2:
-        t = PrettyTable(["df1", "df2"])
-        zipped = list(six.moves.zip_longest(rows1, rows2))
-        for r1, r2 in zipped:
-            if r1 == r2:
-                t.add_row([blue(r1), blue(r2)])
-            else:
-                t.add_row([r1, r2])
-        raise DataFramesNotEqualError("\n" + t.get_string())
+    assert_rows_equality(df1, df2, precision, allow_nan_equality)
```
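The body above folds an optional list of DataFrame-to-DataFrame transforms over each input with `reduce`. A standalone sketch of that pattern, using plain Python lists so it runs without Spark:

```python
# Each transform is a function applied left to right; the same shape as
# df = reduce(lambda acc, fn: fn(acc), transforms, df) in assert_df_equality.
from functools import reduce

transforms = [
    lambda xs: sorted(xs),            # analogous to the ignore_row_order sort
    lambda xs: [x * 10 for x in xs],  # analogous to a user-supplied transform
]

data = [3, 1, 2]
result = reduce(lambda acc, fn: fn(acc), transforms, data)
print(result)  # [10, 20, 30]
```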
Changes to the numeric comparison helpers:

```diff
@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 
 def isnan(x):
@@ -8,5 +9,27 @@ def isnan(x):
         return False
 
 
-def nan_safe_equality(x, y) -> bool:
-    return (x == y) or (isnan(x) and isnan(y))
+def check_equal(
+    x, y,
+    precision: Optional[float] = None,
+    allow_nan_equality: bool = False,
+) -> bool:
+    """Return True if x and y are equal.
+
+    Parameters
+    ----------
+    precision : float, optional
+        Absolute tolerance when checking for equality.
+    allow_nan_equality: bool, defaults to False
+        When True, treats two NaN values as equal.
+
+    """
+    both_floats = (isinstance(x, float) & isinstance(y, float))
```
Review thread on the `both_floats` line:

**Owner:** I don't think we need to do this because we know the types from the schema.

**Author (Contributor):** Hi @MrPowers, can you have a look at my draft pull request here to see if I'm going along the right lines with your 3rd objective? To summarise, rather than the following line:

`both_floats = (isinstance(x, float) & isinstance(y, float))`

I'm doing this:

`is_float_type = (dtype_name in ['float', 'double', 'decimal'])`

where the dtype gets passed into the function, and is created with the following line:

`dtypes = [field.dataType.typeName() for field in df1.schema]`

Because we've already compared the schemas, we know that they're equal, so we just use… Will using…

**Author (Contributor):** I just did some further testing too, and my current solution would not work for pyspark columns with…

**Author (Contributor):** Any time to give me a steer this week @MrPowers? As in, point me in the right direction for the solution (sorry, realised "giving a steer" might be a localised phrase) - where are you from anyway?
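A rough sketch of the schema-driven alternative described in the thread above; the function and variable names here are hypothetical and not part of this PR. The remainder of the new `check_equal` body then continues in the diff below.

```python
import math


def _isnan(value):
    # Mirrors chispa's isnan helper: non-numeric values are never NaN.
    try:
        return math.isnan(value)
    except TypeError:
        return False


def check_equal_with_dtype(x, y, dtype_name, precision=None, allow_nan_equality=False):
    """Hypothetical variant of check_equal that trusts the Spark schema."""
    # Spark type names such as 'float', 'double' or 'decimal' mark columns
    # where an absolute tolerance is meaningful.
    is_float_type = dtype_name in ["float", "double", "decimal"]

    if precision is not None and is_float_type and x is not None and y is not None:
        both_equal = abs(x - y) < precision
    else:
        both_equal = x == y

    both_nan = (_isnan(x) and _isnan(y)) if allow_nan_equality else False
    return both_equal or both_nan


# The per-column type names would come from the already-compared schema, e.g.:
# dtypes = [field.dataType.typeName() for field in df1.schema]
```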
```diff
+    if (precision is not None) & both_floats:
+        both_equal = abs(x - y) < precision
+    else:
+        both_equal = (x == y)
+
+    both_nan = (isnan(x) and isnan(y)) if allow_nan_equality else False
+
+    return both_equal or both_nan
```
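And a quick illustration of how the new `check_equal` behaves (runnable without Spark, assuming the helper is importable from `chispa.number_helpers` as the column comparer diff above does):

```python
from chispa.number_helpers import check_equal

print(check_equal(1.05, 1.0))                    # False: exact comparison by default
print(check_equal(1.05, 1.0, precision=0.1))     # True: difference is within tolerance
print(check_equal(float("nan"), float("nan")))   # False: NaN != NaN by default
print(check_equal(float("nan"), float("nan"), allow_nan_equality=True))  # True
print(check_equal("a", "a", precision=0.1))      # True: non-floats fall back to ==
```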