@@ -1,5 +1,6 @@
1 | 1 | from dataclasses import dataclass |
2 | 2 | from datetime import datetime |
| 3 | +import hashlib |
3 | 4 | import json |
4 | 5 | import threading |
5 | 6 | from multiprocessing.context import SpawnContext |
@@ -53,7 +54,9 @@
53 | 54 | import google.oauth2 |
54 | 55 | import google.cloud.bigquery |
55 | 56 | from google.cloud.bigquery import AccessEntry, SchemaField, Table as BigQueryTable |
| 57 | +from google.cloud import dataplex_v1 |
56 | 58 | import google.cloud.exceptions |
| 59 | +from google.protobuf import field_mask_pb2 |
57 | 60 | import pytz |
58 | 61 |
59 | 62 | from dbt.adapters.bigquery import BigQueryColumn, BigQueryConnectionManager |
@@ -96,6 +99,26 @@ def render(self): |
96 | 99 | return f"{self.project}.{self.dataset}" |
97 | 100 |
98 | 101 |
| 102 | +@dataclass |
| 103 | +class DataProfileScanSetting: |
| 104 | + location: str |
| 105 | + scan_id: Optional[str] |
| 106 | + |
| 107 | + project_id: str |
| 108 | + dataset_id: str |
| 109 | + table_id: str |
| 110 | + |
| 111 | + sampling_percent: Optional[float] |
| 112 | + row_filter: Optional[str] |
| 113 | + cron: Optional[str] |
| 114 | + |
| 115 | + def parent(self): |
| 116 | + return f"projects/{self.project_id}/locations/{self.location}" |
| 117 | + |
| 118 | + def data_scan_name(self): |
| 119 | + return f"{self.parent()}/dataScans/{self.scan_id}" |
| 120 | + |
| 121 | + |
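
A minimal usage sketch of the helpers above, with hypothetical project/dataset/table values, showing the Dataplex resource names they render:

```python
# All values below are hypothetical, purely to illustrate the rendered names.
setting = DataProfileScanSetting(
    location="us-central1",
    scan_id="dbt-orders-0a1b2c",
    project_id="my-project",
    dataset_id="analytics",
    table_id="orders",
    sampling_percent=10.0,
    row_filter=None,
    cron=None,
)
assert setting.parent() == "projects/my-project/locations/us-central1"
assert (
    setting.data_scan_name()
    == "projects/my-project/locations/us-central1/dataScans/dbt-orders-0a1b2c"
)
```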
99 | 122 | def _stub_relation(*args, **kwargs): |
100 | 123 | return BigQueryRelation.create( |
101 | 124 | database="", schema="", identifier="", quote_policy={}, type=BigQueryRelation.Table |
@@ -999,3 +1022,142 @@ def validate_sql(self, sql: str) -> AdapterResponse: |
999 | 1022 | :param str sql: The sql to validate |
1000 | 1023 | """ |
1001 | 1024 | return self.connections.dry_run(sql) |
| 1025 | + |
| 1026 | + # Unless the `dataplex-dp-published-*` labels are assigned to the table, the results of the Data Profile Scan cannot be viewed from BigQuery
| 1027 | + def _update_labels_with_data_profile_scan_labels( |
| 1028 | + self, |
| 1029 | + project_id: str, |
| 1030 | + dataset_id: str, |
| 1031 | + table_id: str, |
| 1032 | + location: str, |
| 1033 | + scan_id: str, |
| 1034 | + ): |
| 1035 | + table = self.connections.get_bq_table(project_id, dataset_id, table_id) |
| 1036 | + original_labels = table.labels |
| 1037 | + profile_scan_labels = { |
| 1038 | + "dataplex-dp-published-scan": scan_id, |
| 1039 | + "dataplex-dp-published-project": project_id, |
| 1040 | + "dataplex-dp-published-location": location, |
| 1041 | + } |
| 1042 | + table.labels = {**original_labels, **profile_scan_labels} |
| 1043 | + self.connections.get_thread_connection().handle.update_table(table, ["labels"]) |
| 1044 | + |
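
The label merge above preserves any labels already on the table and overwrites only the three `dataplex-dp-published-*` keys. A small sketch of that semantics, with hypothetical values:

```python
# Labels already on the table (hypothetical), including a stale published-scan label.
original_labels = {"team": "data", "dataplex-dp-published-scan": "old-scan"}
profile_scan_labels = {
    "dataplex-dp-published-scan": "dbt-orders-0a1b2c",  # hypothetical scan_id
    "dataplex-dp-published-project": "my-project",      # hypothetical project
    "dataplex-dp-published-location": "us-central1",    # hypothetical location
}
# Later keys win on collision, so the stale scan label is replaced in place.
merged = {**original_labels, **profile_scan_labels}
assert merged["team"] == "data"
assert merged["dataplex-dp-published-scan"] == "dbt-orders-0a1b2c"
```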
| 1045 | + # scan_id must be unique within the project and no longer than 63 characters, |
| 1046 | + # so generate an id that meets the constraints |
| 1047 | + def _generate_unique_scan_id(self, dataset_id: str, table_id: str) -> str: |
| 1048 | + md5 = hashlib.md5(f"{dataset_id}_{table_id}".encode("utf-8")).hexdigest() |
| 1049 | + return f"dbt-{table_id.replace('_', '-')}-{md5}"[:63] |
| 1050 | + |
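
A quick standalone check of the constraint handling (hypothetical names): the MD5 hex digest is always 32 characters, so the `[:63]` slice only truncates when `table_id` exceeds 26 characters.

```python
import hashlib

# Standalone copy of the generator above, for illustration only.
def generate_unique_scan_id(dataset_id: str, table_id: str) -> str:
    md5 = hashlib.md5(f"{dataset_id}_{table_id}".encode("utf-8")).hexdigest()
    return f"dbt-{table_id.replace('_', '-')}-{md5}"[:63]

scan_id = generate_unique_scan_id("analytics", "order_items")
# "dbt-" (4) + "order-items" (11) + "-" (1) + digest (32) = 48 characters.
assert len(scan_id) <= 63
assert scan_id.startswith("dbt-order-items-")
```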
| 1051 | + def _create_or_update_data_profile_scan( |
| 1052 | + self, |
| 1053 | + client: dataplex_v1.DataScanServiceClient, |
| 1054 | + scan_setting: DataProfileScanSetting, |
| 1055 | + ): |
| 1056 | + data_profile_spec = dataplex_v1.DataProfileSpec( |
| 1057 | + sampling_percent=scan_setting.sampling_percent, |
| 1058 | + row_filter=scan_setting.row_filter, |
| 1059 | + ) |
| 1060 | + display_name = ( |
| 1061 | + f"Data Profile Scan for {scan_setting.table_id} in {scan_setting.dataset_id}" |
| 1062 | + ) |
| 1063 | + description = f"This is a Data Profile Scan for {scan_setting.project_id}.{scan_setting.dataset_id}.{scan_setting.table_id}. Created by dbt." |
| 1064 | + labels = { |
| 1065 | + "managed_by": "dbt", |
| 1066 | + } |
| 1067 | + |
| 1068 | + if scan_setting.cron: |
| 1069 | + trigger = dataplex_v1.Trigger( |
| 1070 | + schedule=dataplex_v1.Trigger.Schedule(cron=scan_setting.cron) |
| 1071 | + ) |
| 1072 | + else: |
| 1073 | + trigger = dataplex_v1.Trigger(on_demand=dataplex_v1.Trigger.OnDemand()) |
| 1074 | + execution_spec = dataplex_v1.DataScan.ExecutionSpec(trigger=trigger) |
| 1075 | + |
| 1076 | + if all( |
| 1077 | + scan.name != scan_setting.data_scan_name() |
| 1078 | + for scan in client.list_data_scans(parent=scan_setting.parent()) |
| 1079 | + ): |
| 1080 | + data_scan = dataplex_v1.DataScan( |
| 1081 | + data=dataplex_v1.DataSource( |
| 1082 | + resource=f"//bigquery.googleapis.com/projects/{scan_setting.project_id}/datasets/{scan_setting.dataset_id}/tables/{scan_setting.table_id}" |
| 1083 | + ), |
| 1084 | + data_profile_spec=data_profile_spec, |
| 1085 | + execution_spec=execution_spec, |
| 1086 | + display_name=display_name, |
| 1087 | + description=description, |
| 1088 | + labels=labels, |
| 1089 | + ) |
| 1090 | + request = dataplex_v1.CreateDataScanRequest( |
| 1091 | + parent=scan_setting.parent(), |
| 1092 | + data_scan_id=scan_setting.scan_id, |
| 1093 | + data_scan=data_scan, |
| 1094 | + ) |
| 1095 | + client.create_data_scan(request=request).result() |
| 1096 | + else: |
| 1097 | + request = dataplex_v1.GetDataScanRequest( |
| 1098 | + name=scan_setting.data_scan_name(), |
| 1099 | + ) |
| 1100 | + data_scan = client.get_data_scan(request=request) |
| 1101 | + |
| 1102 | + data_scan.data_profile_spec = data_profile_spec |
| 1103 | + data_scan.execution_spec = execution_spec |
| 1104 | + data_scan.display_name = display_name |
| 1105 | + data_scan.description = description |
| 1106 | + data_scan.labels = labels |
| 1107 | + |
| 1108 | + update_mask = field_mask_pb2.FieldMask( |
| 1109 | + paths=[ |
| 1110 | + "data_profile_spec", |
| 1111 | + "execution_spec", |
| 1112 | + "display_name", |
| 1113 | + "description", |
| 1114 | + "labels", |
| 1115 | + ] |
| 1116 | + ) |
| 1117 | + request = dataplex_v1.UpdateDataScanRequest( |
| 1118 | + data_scan=data_scan, |
| 1119 | + update_mask=update_mask, |
| 1120 | + ) |
| 1121 | + client.update_data_scan(request=request).result() |
| 1122 | + |
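
The trigger wiring above is the key branch: a `cron` produces a scheduled scan, otherwise the scan is created on-demand and explicitly run by `create_or_update_data_profile_scan` below. A sketch of the two `dataplex_v1.Trigger` shapes, with a hypothetical cron expression:

```python
from google.cloud import dataplex_v1

# Scheduled: Dataplex runs the scan itself on the cron expression
# (hypothetical schedule: daily at 03:00).
scheduled = dataplex_v1.Trigger(
    schedule=dataplex_v1.Trigger.Schedule(cron="0 3 * * *")
)

# On-demand: the scan only runs when RunDataScan is invoked explicitly.
on_demand = dataplex_v1.Trigger(on_demand=dataplex_v1.Trigger.OnDemand())
```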
| 1123 | + @available |
| 1124 | + def create_or_update_data_profile_scan(self, config): |
| 1125 | + project_id = config.get("database") |
| 1126 | + dataset_id = config.get("schema") |
| 1127 | + table_id = config.get("name") |
| 1128 | + |
| 1129 | + data_profile_config = config.get("config", {}).get("data_profile_scan", {}) |
| 1130 | + |
| 1131 | + # Skip if data_profile_scan is not configured |
| 1132 | + if not data_profile_config: |
| 1133 | + return None |
| 1134 | + |
| 1135 | + client = dataplex_v1.DataScanServiceClient() |
| 1136 | + scan_setting = DataProfileScanSetting( |
| 1137 | + location=data_profile_config["location"], |
| 1138 | + scan_id=data_profile_config.get( |
| 1139 | + "scan_id", self._generate_unique_scan_id(dataset_id, table_id) |
| 1140 | + ), |
| 1141 | + project_id=project_id, |
| 1142 | + dataset_id=dataset_id, |
| 1143 | + table_id=table_id, |
| 1144 | + sampling_percent=data_profile_config.get("sampling_percent", None), |
| 1145 | + row_filter=data_profile_config.get("row_filter", None), |
| 1146 | + cron=data_profile_config.get("cron", None), |
| 1147 | + ) |
| 1148 | + |
| 1149 | + # Delete existing data profile scan if it is disabled |
| 1150 | + if not data_profile_config.get("enabled", True): |
| 1151 | + client.delete_data_scan(name=scan_setting.data_scan_name()).result() |
| 1152 | + return None |
| 1153 | + |
| 1154 | + self._create_or_update_data_profile_scan(client, scan_setting) |
| 1155 | + |
| 1156 | + if not scan_setting.cron: |
| 1157 | + client.run_data_scan( |
| 1158 | + request=dataplex_v1.RunDataScanRequest(name=scan_setting.data_scan_name()) |
| 1159 | + ) |
| 1160 | + |
| 1161 | + self._update_labels_with_data_profile_scan_labels( |
| 1162 | + project_id, dataset_id, table_id, scan_setting.location, scan_setting.scan_id |
| 1163 | + ) |
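
For reference, a hypothetical shape of the `config` dict this method receives; in practice dbt derives it from the model's configuration, and every key under `data_profile_scan` except `location` is optional:

```python
config = {
    "database": "my-project",  # becomes project_id
    "schema": "analytics",     # becomes dataset_id
    "name": "orders",          # becomes table_id
    "config": {
        "data_profile_scan": {
            "location": "us-central1",        # required
            # "scan_id": "my-scan",           # auto-generated when omitted
            "sampling_percent": 10.0,
            "row_filter": "status = 'done'",
            "cron": "0 3 * * *",              # omitted -> scan runs on demand
            "enabled": True,                  # False deletes an existing scan
        },
    },
}
```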