diff --git a/src/data_tools/analysis/models.py b/src/data_tools/analysis/models.py
index 0b1adbb..f19ffa0 100644
--- a/src/data_tools/analysis/models.py
+++ b/src/data_tools/analysis/models.py
@@ -4,12 +4,17 @@
 from typing import Any, Dict, Optional

+import pandas as pd
 import yaml

 from data_tools.common.exception import errors
 from data_tools.core import settings
 from data_tools.dataframes.factory import DataFrameFactory
-from data_tools.dataframes.models import ColumnProfile
+from data_tools.dataframes.models import (
+    ColumnProfile,
+    DataTypeIdentificationL2Input,
+    ProfilingOutput,
+)
 from data_tools.models.resources.model import Column, ColumnProfilingMetrics
 from data_tools.models.resources.source import Source, SourceTables
@@ -31,6 +36,126 @@ def __init__(self, df: Any, name: str):
         # A dictionary to store the results of each analysis step
         self.results: Dict[str, Any] = {}
+
+    def profile_table(self) -> "DataSet":
+        """
+        Profiles the table and stores the result in the 'results' dictionary.
+        """
+        self.results["table_profile"] = self.dataframe_wrapper.profile(self.raw_df)
+        return self
+
+    def profile_columns(self) -> "DataSet":
+        """
+        Profiles each column in the dataset and stores the results in the 'results' dictionary.
+        This method relies on the 'table_profile' result to get the list of columns.
+        """
+        if "table_profile" not in self.results:
+            raise RuntimeError("TableProfiler must be run before profiling columns.")
+
+        table_profile: ProfilingOutput = self.results["table_profile"]
+        self.results["column_profiles"] = {
+            col_name: self.dataframe_wrapper.column_profile(
+                self.raw_df, self.name, col_name, settings.UPSTREAM_SAMPLE_LIMIT
+            )
+            for col_name in table_profile.columns
+        }
+        return self
+
+    def identify_datatypes_l1(self) -> "DataSet":
+        """
+        Identifies the data types at Level 1 for each column based on the column profiles.
+        This method relies on the 'column_profiles' result.
+        """
+        if "column_profiles" not in self.results:
+            raise RuntimeError("TableProfiler and ColumnProfiler must be run before data type identification.")
+
+        column_profiles: dict[str, ColumnProfile] = self.results["column_profiles"]
+        column_datatypes_l1 = self.dataframe_wrapper.datatype_identification_l1(self.raw_df, self.name, column_profiles)
+
+        for column in column_datatypes_l1:
+            column_profiles[column.column_name].datatype_l1 = column.datatype_l1
+
+        self.results["column_datatypes_l1"] = column_datatypes_l1
+        return self
+
+    def identify_datatypes_l2(self) -> "DataSet":
+        """
+        Identifies the data types at Level 2 for each column based on the column profiles.
+        This method relies on the 'column_profiles' result.
+        """
+        if "column_profiles" not in self.results:
+            raise RuntimeError("TableProfiler and ColumnProfiler must be run before data type identification.")
+
+        column_profiles: dict[str, ColumnProfile] = self.results["column_profiles"]
+        columns_with_samples = [DataTypeIdentificationL2Input(**col.model_dump()) for col in column_profiles.values()]
+        column_datatypes_l2 = self.dataframe_wrapper.datatype_identification_l2(
+            self.raw_df, self.name, columns_with_samples
+        )
+
+        for column in column_datatypes_l2:
+            column_profiles[column.column_name].datatype_l2 = column.datatype_l2
+
+        self.results["column_datatypes_l2"] = column_datatypes_l2
+        return self
+
+    def identify_keys(self) -> "DataSet":
+        """
+        Identifies potential primary keys in the dataset based on column profiles.
+        This method relies on the 'column_datatypes_l1' and 'column_datatypes_l2' results.
+ """ + if "column_datatypes_l1" not in self.results or "column_datatypes_l2" not in self.results: + raise RuntimeError("DataTypeIdentifierL1 and L2 must be run before KeyIdentifier.") + + column_profiles: dict[str, ColumnProfile] = self.results["column_profiles"] + column_profiles_df = pd.DataFrame([col.model_dump() for col in column_profiles.values()]) + + key = self.dataframe_wrapper.key_identification(self.name, column_profiles_df) + if key is not None: + self.results["key"] = key + return self + + def profile(self) -> None: + """ + Profiles the dataset including table and columns and stores the result in the 'results' dictionary. + This is a convenience method to run profiling on the raw dataframe. + """ + if self.raw_df.empty: + raise ValueError("The raw dataframe is empty. Cannot perform profiling.") + self.profile_table().profile_columns() + return self + + def identify_datatypes(self) -> None: + """ + Identifies the data types for the dataset and stores the result in the 'results' dictionary. + This is a convenience method to run data type identification on the raw dataframe. + """ + if self.raw_df.empty: + raise ValueError("The raw dataframe is empty. Cannot perform data type identification.") + self.identify_datatypes_l1().identify_datatypes_l2() + return self + + def generate_glossary(self, domain: str = "") -> "DataSet": + """ + Generates a business glossary for the dataset and stores the result in the 'results' dictionary. + This method relies on the 'column_datatypes_l1' results. + """ + if "column_datatypes_l1" not in self.results: + raise RuntimeError("DataTypeIdentifierL1 must be run before Business Glossary Generation.") + + column_profiles: dict[str, ColumnProfile] = self.results["column_profiles"] + column_profiles_df = pd.DataFrame([col.model_dump() for col in column_profiles.values()]) + + glossary_output = self.dataframe_wrapper.generate_business_glossary( + self.name, column_profiles_df, domain=domain + ) + + for column in glossary_output.columns: + column_profiles[column.column_name].business_glossary = column.business_glossary + column_profiles[column.column_name].business_tags = column.business_tags + + self.results["business_glossary_and_tags"] = glossary_output + self.results["table_glossary"] = glossary_output.table_glossary + return self # FIXME - this is a temporary solution to save the results of the analysis # need to use model while executing the pipeline @@ -77,3 +202,10 @@ def save_yaml(self, file_path: Optional[str] = None) -> None: # Save the YAML representation of the sources with open(file_path, "w") as file: yaml.dump(sources, file, sort_keys=False, default_flow_style=False) + + def _repr_html_(self): + column_profiles = self.results.get("column_profiles") + if column_profiles is None: + return "

No column profiles available.

" + df = pd.DataFrame([col.model_dump() for col in column_profiles.values()]) + return df._repr_html_() diff --git a/src/data_tools/analysis/steps.py b/src/data_tools/analysis/steps.py index 3e5559b..9c6b6be 100644 --- a/src/data_tools/analysis/steps.py +++ b/src/data_tools/analysis/steps.py @@ -1,16 +1,7 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -import pandas as pd - -from data_tools.core.settings import settings -from data_tools.dataframes.models import DataTypeIdentificationL2Input from .models import DataSet -if TYPE_CHECKING: - from data_tools.dataframes.models import ColumnProfile, ProfilingOutput - class AnalysisStep(ABC): """Abstract base class for any step in our analysis pipeline.""" @@ -28,8 +19,7 @@ def analyze(self, dataset: DataSet) -> None: """ Performs table-level profiling and saves the result. """ - profile = dataset.dataframe_wrapper.profile(dataset.raw_df) - dataset.results["table_profile"] = profile + dataset.profile_table() class ColumnProfiler(AnalysisStep): @@ -38,22 +28,7 @@ def analyze(self, dataset: DataSet) -> None: Performs column-level profiling for each column. This step depends on the 'table_profile' result. """ - - # Dependency check - if "table_profile" not in dataset.results: - raise RuntimeError("TableProfiler must be run before ColumnProfiler.") - - table_profile: ProfilingOutput = dataset.results["table_profile"] - all_column_profiles = {} - - for col_name in table_profile.columns: - # We would add a method to our DataFrame wrapper to get stats for a single column - stats = dataset.dataframe_wrapper.column_profile( - dataset.raw_df, dataset.name, col_name, settings.UPSTREAM_SAMPLE_LIMIT - ) - all_column_profiles[col_name] = stats - - dataset.results["column_profiles"] = all_column_profiles + dataset.profile_columns() class DataTypeIdentifierL1(AnalysisStep): @@ -62,21 +37,7 @@ def analyze(self, dataset: DataSet) -> None: Performs datatype identification level 1 for each column. This step depends on the 'column_profiles' result. """ - - # Dependency check - if "column_profiles" not in dataset.results: - raise RuntimeError("TableProfiler and ColumnProfiler must be run before DatatypeIdentifierL1.") - - column_profiles: dict[str, ColumnProfile] = dataset.results["column_profiles"] - - column_datatypes_l1 = dataset.dataframe_wrapper.datatype_identification_l1( - dataset.raw_df, dataset.name, column_profiles - ) - - for column in column_datatypes_l1: - column_profiles[column.column_name].datatype_l1 = column.datatype_l1 - - dataset.results["column_datatypes_l1"] = column_datatypes_l1 + dataset.identify_datatypes_l1() class DataTypeIdentifierL2(AnalysisStep): @@ -85,21 +46,7 @@ def analyze(self, dataset: DataSet) -> None: Performs datatype identification level 2 for each column. This step depends on the 'column_datatypes_l1' result. 
""" - - # Dependency check - if "column_profiles" not in dataset.results: - raise RuntimeError("TableProfiler and ColumnProfiler must be run before DatatypeIdentifierL2.") - - column_profiles: dict[str, ColumnProfile] = dataset.results["column_profiles"] - columns_with_samples = [DataTypeIdentificationL2Input(**col.model_dump()) for col in column_profiles.values()] - column_datatypes_l2 = dataset.dataframe_wrapper.datatype_identification_l2( - dataset.raw_df, dataset.name, columns_with_samples - ) - - for column in column_datatypes_l2: - column_profiles[column.column_name].datatype_l2 = column.datatype_l2 - - dataset.results["column_datatypes_l2"] = column_datatypes_l2 + dataset.identify_datatypes_l2() class KeyIdentifier(AnalysisStep): @@ -108,15 +55,7 @@ def analyze(self, dataset: DataSet) -> None: Performs key identification for the dataset. This step depends on the datatype identification results. """ - if "column_datatypes_l1" not in dataset.results or "column_datatypes_l2" not in dataset.results: - raise RuntimeError("DataTypeIdentifierL1 and L2 must be run before KeyIdentifier.") - - column_profiles: dict[str, ColumnProfile] = dataset.results["column_profiles"] - column_profiles_df = pd.DataFrame([col.model_dump() for col in column_profiles.values()]) - - key = dataset.dataframe_wrapper.key_identification(dataset.name, column_profiles_df) - if key is not None: - dataset.results["key"] = key + dataset.identify_keys() class BusinessGlossaryGenerator(AnalysisStep): @@ -131,19 +70,4 @@ def analyze(self, dataset: DataSet) -> None: """ Generates business glossary terms and tags for each column in the dataset. """ - if "column_datatypes_l1" not in dataset.results: - raise RuntimeError("DataTypeIdentifierL1 must be run before Business Glossary Generation.") - - column_profiles: dict[str, ColumnProfile] = dataset.results["column_profiles"] - column_profiles_df = pd.DataFrame([col.model_dump() for col in column_profiles.values()]) - - glossary_output = dataset.dataframe_wrapper.generate_business_glossary( - dataset.name, column_profiles_df, domain=self.domain - ) - - for column in glossary_output.columns: - column_profiles[column.column_name].business_glossary = column.business_glossary - column_profiles[column.column_name].business_tags = column.business_tags - - dataset.results["business_glossary_and_tags"] = glossary_output - dataset.results["table_glossary"] = glossary_output.table_glossary + dataset.generate_glossary(self.domain) diff --git a/tests/analysis/test_high_level.py b/tests/analysis/test_high_level.py new file mode 100644 index 0000000..ace014d --- /dev/null +++ b/tests/analysis/test_high_level.py @@ -0,0 +1,94 @@ +import pandas as pd +import pytest + +from data_tools.analysis.models import DataSet + + +@pytest.fixture +def sample_dataframe(): + """Fixture to provide a sample DataFrame for testing.""" + return pd.DataFrame({ + "user_id": [1, 2, 3, 4, 5], + "product_name": ["Laptop", "Mouse", "Keyboard", "Monitor", "Webcam"], + "price": [1200.50, 25.00, 75.99, 300.00, 55.50], + "purchase_date": pd.to_datetime([ + "2023-01-15", "2023-01-16", "2023-01-17", "2023-01-18", "2023-01-19" + ]), + }) + + +def test_profile(sample_dataframe): + """Test the profile convenience method.""" + dataset = DataSet(df=sample_dataframe, name="test_table") + dataset.profile() + + assert "table_profile" in dataset.results + table_profile = dataset.results["table_profile"] + assert table_profile is not None + assert table_profile.count == 5 + assert set(table_profile.columns) == {"user_id", 
"product_name", "price", "purchase_date"} + + assert "column_profiles" in dataset.results + column_profiles = dataset.results["column_profiles"] + assert column_profiles is not None + assert len(column_profiles) == 4 + + +def test_identify_datatypes(sample_dataframe): + """Test the identify_datatypes convenience method.""" + dataset = DataSet(df=sample_dataframe, name="test_table") + dataset.profile() + dataset.identify_datatypes() + + assert "column_datatypes_l1" in dataset.results + column_datatypes_l1 = dataset.results["column_datatypes_l1"] + assert column_datatypes_l1 is not None + assert len(column_datatypes_l1) == 4 + + assert "column_datatypes_l2" in dataset.results + column_datatypes_l2 = dataset.results["column_datatypes_l2"] + assert column_datatypes_l2 is not None + assert len(column_datatypes_l2) == 4 + + +def test_identify_keys(sample_dataframe): + """Test the identify_keys method.""" + dataset = DataSet(df=sample_dataframe, name="test_table") + dataset.profile() + dataset.identify_datatypes() + dataset.identify_keys() + + assert "key" in dataset.results + key = dataset.results["key"] + assert key is not None + + +def test_generate_glossary(sample_dataframe): + """Test the generate_glossary method.""" + dataset = DataSet(df=sample_dataframe, name="test_table") + dataset.profile() + dataset.identify_datatypes() + dataset.generate_glossary(domain="ecommerce") + + assert "business_glossary_and_tags" in dataset.results + glossary = dataset.results["business_glossary_and_tags"] + assert glossary is not None + assert "table_glossary" in dataset.results + table_glossary = dataset.results["table_glossary"] + assert table_glossary is not None + + +def test_save_yaml(sample_dataframe, tmp_path): + """Test the save_yaml method.""" + dataset = DataSet(df=sample_dataframe, name="test_table") + dataset.profile() + dataset.identify_datatypes() + dataset.generate_glossary(domain="ecommerce") + + file_path = tmp_path / "test_table.yml" + dataset.save_yaml(file_path=str(file_path)) + + assert file_path.exists() + with open(file_path, "r") as file: + content = file.read() + assert "sources" in content diff --git a/tests/parser/test_manifest.py b/tests/parser/test_manifest.py index 87c2316..f7d946a 100644 --- a/tests/parser/test_manifest.py +++ b/tests/parser/test_manifest.py @@ -1,10 +1,9 @@ +from data_tools.core import settings from data_tools.parser.manifest import ManifestLoader -PROJECT_BASE = "/home/juhel-phanju/Documents/backup/MIGRATION/codes/poc/dbt/ecom/ecom/models" - def test_manifet(): - manifest_loader = ManifestLoader(PROJECT_BASE) + manifest_loader = ManifestLoader(settings.PROJECT_BASE) manifest_loader.load() assert manifest_loader.manifest is not None