|
from abc import ABC, abstractmethod
from typing import cast

from pandas import DataFrame as PandasDataFrame
from pyspark.sql import DataFrame as PySparkDataFrame

# Import the connect DataFrame type for Spark Connect
try:
    from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
except ImportError:
    # Fallback for older PySpark versions that don't have connect
    PySparkConnectDataFrame = None  # type: ignore[misc,assignment]

from dataframe_expectations.core.types import DataFrameLike, DataFrameType
from dataframe_expectations.result_message import (
    DataFrameExpectationResultMessage,
)
19 | | - |
class DataFrameExpectation(ABC):
    """
    Base class for DataFrame expectations.

    Concrete subclasses implement ``get_description``, ``validate_pandas``,
    and ``validate_pyspark``; ``validate`` dispatches to the backend-specific
    method based on the runtime type of the DataFrame.
    """

    def get_expectation_name(self) -> str:
        """
        Returns the class name as the expectation name.
        """
        return type(self).__name__

    @abstractmethod
    def get_description(self) -> str:
        """
        Returns a human-readable description of the expectation.
        """
        raise NotImplementedError(
            f"description method must be implemented for {self.__class__.__name__}"
        )

    def __str__(self) -> str:
        """
        Returns a string representation of the expectation:
        ``Name (description)``.
        """
        return f"{self.get_expectation_name()} ({self.get_description()})"

    @classmethod
    def infer_data_frame_type(cls, data_frame: DataFrameLike) -> DataFrameType:
        """
        Infer the DataFrame type based on the provided DataFrame.

        Spark Connect DataFrames map to ``DataFrameType.PYSPARK`` as well,
        since they share the PySpark validation path.

        :raises ValueError: if ``data_frame`` is neither a pandas nor a
            PySpark (classic or Connect) DataFrame.
        """
        if isinstance(data_frame, PandasDataFrame):
            return DataFrameType.PANDAS
        if isinstance(data_frame, PySparkDataFrame):
            return DataFrameType.PYSPARK
        # PySparkConnectDataFrame is None when the installed PySpark predates
        # Spark Connect; guard before isinstance to avoid passing None to it.
        if PySparkConnectDataFrame is not None and isinstance(
            data_frame, PySparkConnectDataFrame
        ):
            return DataFrameType.PYSPARK
        raise ValueError(f"Unsupported DataFrame type: {type(data_frame)}")

    def validate(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Validate the DataFrame against the expectation.

        Dispatches to :meth:`validate_pandas` or :meth:`validate_pyspark`
        according to the inferred DataFrame type; extra keyword arguments are
        forwarded unchanged.

        :raises ValueError: if the DataFrame type is unsupported.
        """
        data_frame_type = self.infer_data_frame_type(data_frame)

        if data_frame_type == DataFrameType.PANDAS:
            return self.validate_pandas(data_frame=data_frame, **kwargs)
        if data_frame_type == DataFrameType.PYSPARK:
            return self.validate_pyspark(data_frame=data_frame, **kwargs)
        raise ValueError(f"Unsupported DataFrame type: {data_frame_type}")

    @abstractmethod
    def validate_pandas(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Validate a pandas DataFrame against the expectation.
        """
        raise NotImplementedError(
            f"validate_pandas method must be implemented for {self.__class__.__name__}"
        )

    @abstractmethod
    def validate_pyspark(
        self, data_frame: DataFrameLike, **kwargs
    ) -> DataFrameExpectationResultMessage:
        """
        Validate a PySpark DataFrame against the expectation.
        """
        raise NotImplementedError(
            f"validate_pyspark method must be implemented for {self.__class__.__name__}"
        )

    @classmethod
    def num_data_frame_rows(cls, data_frame: DataFrameLike) -> int:
        """
        Count the number of rows in the DataFrame.

        :raises ValueError: if the DataFrame type is unsupported.
        """
        data_frame_type = cls.infer_data_frame_type(data_frame)
        if data_frame_type == DataFrameType.PANDAS:
            # len(df) on a pandas DataFrame is the row count (len of the index).
            # Cast narrows the union type for the type checker.
            return len(cast(PandasDataFrame, data_frame))
        if data_frame_type == DataFrameType.PYSPARK:
            # count() triggers a Spark action; cast narrows for the type checker.
            return cast(PySparkDataFrame, data_frame).count()
        raise ValueError(f"Unsupported DataFrame type: {data_frame_type}")
| 1 | +"""Expectations package - contains all expectation implementations.""" |
0 commit comments