From 4863600237e4cafdc2c128590a989bbe38ece38c Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Wed, 6 Nov 2024 11:02:03 +0800 Subject: [PATCH] Support sql check --- setup.py | 1 + .../core/insights/sql/base/insight.py | 11 +- src/datapilot/core/platforms/dbt/executor.py | 2 + .../core/platforms/dbt/insights/__init__.py | 2 + .../core/platforms/dbt/insights/base.py | 3 + .../platforms/dbt/insights/sql/__init__.py | 0 .../core/platforms/dbt/insights/sql/base.py | 23 ++++ .../platforms/dbt/insights/sql/sql_check.py | 101 ++++++++++++++++++ .../dbt/wrappers/manifest/v10/wrapper.py | 5 + .../dbt/wrappers/manifest/v11/wrapper.py | 5 + .../dbt/wrappers/manifest/v12/wrapper.py | 5 + .../dbt/wrappers/manifest/wrapper.py | 5 + 12 files changed, 154 insertions(+), 9 deletions(-) create mode 100644 src/datapilot/core/platforms/dbt/insights/sql/__init__.py create mode 100644 src/datapilot/core/platforms/dbt/insights/sql/base.py create mode 100644 src/datapilot/core/platforms/dbt/insights/sql/sql_check.py diff --git a/setup.py b/setup.py index 7551cc27..3dbf5ca6 100755 --- a/setup.py +++ b/setup.py @@ -67,6 +67,7 @@ def read(*names, **kwargs): "ruamel.yaml==0.18.6", "tabulate==0.9.0", "requests==2.31.0", + "sqlglot==25.30.0", ], extras_require={ # eg: diff --git a/src/datapilot/core/insights/sql/base/insight.py b/src/datapilot/core/insights/sql/base/insight.py index 6e452f0b..14baeb1c 100644 --- a/src/datapilot/core/insights/sql/base/insight.py +++ b/src/datapilot/core/insights/sql/base/insight.py @@ -1,18 +1,11 @@ from abc import abstractmethod -from typing import Optional -from datapilot.core.insights.base.insight import Insight -from datapilot.schemas.sql import Dialect +from datapilot.core.platforms.dbt.insights.checks.base import ChecksInsight -class SqlInsight(Insight): +class SqlInsight(ChecksInsight): NAME = "SqlInsight" - def __init__(self, sql: str, dialect: Optional[Dialect], *args, **kwargs): - self.sql = sql - self.dialect = dialect - super().__init__(*args, **kwargs) - @abstractmethod def generate(self, *args, **kwargs) -> dict: pass diff --git a/src/datapilot/core/platforms/dbt/executor.py b/src/datapilot/core/platforms/dbt/executor.py index b6989626..13544eff 100644 --- a/src/datapilot/core/platforms/dbt/executor.py +++ b/src/datapilot/core/platforms/dbt/executor.py @@ -51,6 +51,7 @@ def __init__( self.macros = self.manifest_wrapper.get_macros() self.sources = self.manifest_wrapper.get_sources() self.exposures = self.manifest_wrapper.get_exposures() + self.adapter_type = self.manifest_wrapper.get_adapter_type() self.seeds = self.manifest_wrapper.get_seeds() self.children_map = self.manifest_wrapper.parent_to_child_map(self.nodes) self.tests = self.manifest_wrapper.get_tests() @@ -112,6 +113,7 @@ def run(self): children_map=self.children_map, tests=self.tests, project_name=self.project_name, + adapter_type=self.adapter_type, config=self.config, selected_models=self.selected_models, excluded_models=self.excluded_models, diff --git a/src/datapilot/core/platforms/dbt/insights/__init__.py b/src/datapilot/core/platforms/dbt/insights/__init__.py index d903bb7f..c7cf6b0c 100644 --- a/src/datapilot/core/platforms/dbt/insights/__init__.py +++ b/src/datapilot/core/platforms/dbt/insights/__init__.py @@ -51,6 +51,7 @@ from datapilot.core.platforms.dbt.insights.modelling.unused_sources import DBTUnusedSources from datapilot.core.platforms.dbt.insights.performance.chain_view_linking import DBTChainViewLinking from datapilot.core.platforms.dbt.insights.performance.exposure_parent_materializations import DBTExposureParentMaterialization +from datapilot.core.platforms.dbt.insights.sql.sql_check import SqlCheck from datapilot.core.platforms.dbt.insights.structure.model_directories_structure import DBTModelDirectoryStructure from datapilot.core.platforms.dbt.insights.structure.model_naming_conventions import DBTModelNamingConvention from datapilot.core.platforms.dbt.insights.structure.source_directories_structure import DBTSourceDirectoryStructure @@ -112,4 +113,5 @@ CheckSourceHasTests, CheckSourceTableHasDescription, CheckSourceTags, + SqlCheck, ] diff --git a/src/datapilot/core/platforms/dbt/insights/base.py b/src/datapilot/core/platforms/dbt/insights/base.py index 50f0c5a7..c4009090 100644 --- a/src/datapilot/core/platforms/dbt/insights/base.py +++ b/src/datapilot/core/platforms/dbt/insights/base.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing import Dict from typing import List +from typing import Optional from typing import Union from datapilot.config.utils import get_insight_config @@ -33,6 +34,7 @@ def __init__( macros: Dict[str, AltimateManifestMacroNode], children_map: Dict[str, List[str]], project_name: str, + adapter_type: Optional[str], selected_models: Union[List[str], None] = None, excluded_models: Union[List[str], None] = None, *args, @@ -47,6 +49,7 @@ def __init__( self.seeds = seeds self.children_map = children_map self.project_name = project_name + self.adapter_type = adapter_type self.selected_models = selected_models self.excluded_models = excluded_models super().__init__(*args, **kwargs) diff --git a/src/datapilot/core/platforms/dbt/insights/sql/__init__.py b/src/datapilot/core/platforms/dbt/insights/sql/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datapilot/core/platforms/dbt/insights/sql/base.py b/src/datapilot/core/platforms/dbt/insights/sql/base.py new file mode 100644 index 00000000..22d6e49d --- /dev/null +++ b/src/datapilot/core/platforms/dbt/insights/sql/base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from typing import Tuple + +from datapilot.core.platforms.dbt.insights.base import DBTInsight + + +class SqlInsight(DBTInsight): + TYPE = "governance" + + @abstractmethod + def generate(self, *args, **kwargs) -> dict: + pass + + @classmethod + def has_all_required_data(cls, has_manifest: bool, **kwargs) -> Tuple[bool, str]: + """ + Check if all required data is available for the insight to run. + :param has_manifest: A boolean indicating if manifest is available. + :return: A boolean indicating if all required data is available. + """ + if not has_manifest: + return False, "manifest is required for insight to run." + return True, "" diff --git a/src/datapilot/core/platforms/dbt/insights/sql/sql_check.py b/src/datapilot/core/platforms/dbt/insights/sql/sql_check.py new file mode 100644 index 00000000..16774b69 --- /dev/null +++ b/src/datapilot/core/platforms/dbt/insights/sql/sql_check.py @@ -0,0 +1,101 @@ +import inspect +from typing import List + +from sqlglot import parse_one +from sqlglot.optimizer.eliminate_ctes import eliminate_ctes +from sqlglot.optimizer.eliminate_joins import eliminate_joins +from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries +from sqlglot.optimizer.normalize import normalize +from sqlglot.optimizer.pushdown_projections import pushdown_projections +from sqlglot.optimizer.qualify import qualify +from sqlglot.optimizer.unnest_subqueries import unnest_subqueries + +from datapilot.core.insights.sql.base.insight import SqlInsight +from datapilot.core.insights.utils import get_severity +from datapilot.core.platforms.dbt.insights.schema import DBTInsightResult +from datapilot.core.platforms.dbt.insights.schema import DBTModelInsightResponse + +RULES = ( + pushdown_projections, + normalize, + unnest_subqueries, + eliminate_subqueries, + eliminate_joins, + eliminate_ctes, +) + + +class SqlCheck(SqlInsight): + """ + This class identifies DBT models with SQL optimization issues. + """ + + NAME = "sql optimization issues" + ALIAS = "check_sql_optimization" + DESCRIPTION = "Checks if the model has SQL optimization issues. " + REASON_TO_FLAG = "The query can be optimized." + FAILURE_MESSAGE = "The query for model `{model_unique_id}` has optimization opportunities:\n{rule_name}. " + RECOMMENDATION = "Please adapt the query of the model `{model_unique_id}` as in following example:\n{optimized_sql}" + + def _build_failure_result(self, model_unique_id: str, rule_name: str, optimized_sql: str) -> DBTInsightResult: + """ + Constructs a failure result for a given model with sql optimization issues. + :param model_unique_id: The unique id of the dbt model. + :param rule_name: The rule that generated this failure result. + :param optimized_sql: The optimized sql. + :return: An instance of DBTInsightResult containing failure details. + """ + failure_message = self.FAILURE_MESSAGE.format(model_unique_id=model_unique_id, rule_name=rule_name) + recommendation = self.RECOMMENDATION.format(model_unique_id=model_unique_id, optimized_sql=optimized_sql) + return DBTInsightResult( + type=self.TYPE, + name=self.NAME, + message=failure_message, + recommendation=recommendation, + reason_to_flag=self.REASON_TO_FLAG, + metadata={"model_unique_id": model_unique_id, "rule_name": rule_name}, + ) + + def generate(self, *args, **kwargs) -> List[DBTModelInsightResponse]: + """ + Generates insights for each DBT model in the project, focusing on sql optimization issues. + + :return: A list of DBTModelInsightResponse objects with insights for each model. + """ + self.logger.debug("Generating sql insights for DBT models") + insights = [] + + possible_kwargs = { + "db": None, + "catalog": None, + "dialect": self.adapter_type, + "isolate_tables": True, # needed for other optimizations to perform well + "quote_identifiers": False, + **kwargs, + } + for node_id, node in self.nodes.items(): + try: + compiled_query = node.compiled_code + if compiled_query: + parsed_query = parse_one(compiled_query, dialect=self.adapter_type) + qualified = qualify(parsed_query, **possible_kwargs) + changed = qualified.copy() + for rule in RULES: + original = changed.copy() + rule_params = inspect.getfullargspec(rule).args + rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs} + changed = rule(changed, **rule_kwargs) + if changed.sql() != original.sql(): + insights.append( + DBTModelInsightResponse( + unique_id=node_id, + package_name=node.package_name, + path=node.original_file_path, + original_file_path=node.original_file_path, + insight=self._build_failure_result(node_id, rule.__name__, changed.sql()), + severity=get_severity(self.config, self.ALIAS, self.DEFAULT_SEVERITY), + ) + ) + except Exception as e: + self.logger.error(e) + return insights diff --git a/src/datapilot/core/platforms/dbt/wrappers/manifest/v10/wrapper.py b/src/datapilot/core/platforms/dbt/wrappers/manifest/v10/wrapper.py index bc74152a..8addc8d3 100644 --- a/src/datapilot/core/platforms/dbt/wrappers/manifest/v10/wrapper.py +++ b/src/datapilot/core/platforms/dbt/wrappers/manifest/v10/wrapper.py @@ -1,4 +1,5 @@ from typing import Dict +from typing import Optional from typing import Set from dbt_artifacts_parser.parsers.manifest.manifest_v10 import GenericTestNode @@ -67,6 +68,7 @@ def _get_node(self, node: ManifestNode) -> AltimateManifestNode: depends_on_macros = node.depends_on.macros if node.depends_on else None compiled_path = node.compiled_path compiled = node.compiled + compiled_code = node.compiled_code raw_code = node.raw_code language = node.language contract = AltimateDBTContract(**node.contract.__dict__) if node.contract else None @@ -381,6 +383,9 @@ def get_seeds(self) -> Dict[str, AltimateSeedNode]: seeds[seed.unique_id] = self._get_seed(seed) return seeds + def get_adapter_type(self) -> Optional[str]: + return self.manifest.metadata.adapter_type + def parent_to_child_map(self, nodes: Dict[str, AltimateManifestNode]) -> Dict[str, Set[str]]: """ Current manifest contains information about parents diff --git a/src/datapilot/core/platforms/dbt/wrappers/manifest/v11/wrapper.py b/src/datapilot/core/platforms/dbt/wrappers/manifest/v11/wrapper.py index b7eefe7a..c9d04731 100644 --- a/src/datapilot/core/platforms/dbt/wrappers/manifest/v11/wrapper.py +++ b/src/datapilot/core/platforms/dbt/wrappers/manifest/v11/wrapper.py @@ -1,4 +1,5 @@ from typing import Dict +from typing import Optional from typing import Set from dbt_artifacts_parser.parsers.manifest.manifest_v11 import GenericTestNode @@ -67,6 +68,7 @@ def _get_node(self, node: ManifestNode) -> AltimateManifestNode: depends_on_macros = node.depends_on.macros if node.depends_on else None compiled_path = node.compiled_path compiled = node.compiled + compiled_code = node.compiled_code raw_code = node.raw_code language = node.language contract = AltimateDBTContract(**node.contract.__dict__) if node.contract else None @@ -381,6 +383,9 @@ def get_seeds(self) -> Dict[str, AltimateSeedNode]: seeds[seed.unique_id] = self._get_seed(seed) return seeds + def get_adapter_type(self) -> Optional[str]: + return self.manifest.metadata.adapter_type + def parent_to_child_map(self, nodes: Dict[str, AltimateManifestNode]) -> Dict[str, Set[str]]: """ Current manifest contains information about parents diff --git a/src/datapilot/core/platforms/dbt/wrappers/manifest/v12/wrapper.py b/src/datapilot/core/platforms/dbt/wrappers/manifest/v12/wrapper.py index 3c674b01..d694f514 100644 --- a/src/datapilot/core/platforms/dbt/wrappers/manifest/v12/wrapper.py +++ b/src/datapilot/core/platforms/dbt/wrappers/manifest/v12/wrapper.py @@ -1,4 +1,5 @@ from typing import Dict +from typing import Optional from typing import Set from dbt_artifacts_parser.parsers.manifest.manifest_v12 import ManifestV12 @@ -67,6 +68,7 @@ def _get_node(self, node: ManifestNode) -> AltimateManifestNode: depends_on_macros = node.depends_on.macros if node.depends_on else None compiled_path = node.compiled_path compiled = node.compiled + compiled_code = node.compiled_code raw_code = node.raw_code language = node.language contract = AltimateDBTContract(**node.contract.__dict__) if node.contract else None @@ -393,6 +395,9 @@ def get_seeds(self) -> Dict[str, AltimateSeedNode]: seeds[seed.unique_id] = self._get_seed(seed) return seeds + def get_adapter_type(self) -> Optional[str]: + return self.manifest.metadata.adapter_type + def parent_to_child_map(self, nodes: Dict[str, AltimateManifestNode]) -> Dict[str, Set[str]]: """ Current manifest contains information about parents diff --git a/src/datapilot/core/platforms/dbt/wrappers/manifest/wrapper.py b/src/datapilot/core/platforms/dbt/wrappers/manifest/wrapper.py index ee3ec769..bb402cda 100644 --- a/src/datapilot/core/platforms/dbt/wrappers/manifest/wrapper.py +++ b/src/datapilot/core/platforms/dbt/wrappers/manifest/wrapper.py @@ -1,6 +1,7 @@ from abc import ABC from abc import abstractmethod from typing import Dict +from typing import Optional from typing import Set from datapilot.core.platforms.dbt.schemas.manifest import AltimateManifestExposureNode @@ -26,6 +27,10 @@ def get_sources(self) -> Dict[str, AltimateManifestSourceNode]: def get_exposures(self) -> Dict[str, AltimateManifestExposureNode]: pass + @abstractmethod + def get_adapter_type(self) -> Optional[str]: + pass + @abstractmethod def parent_to_child_map(self, nodes: Dict[str, AltimateManifestNode]) -> Dict[str, Set[str]]: pass