Commit 9e59bf9

mayurinehate authored and yoonhyejin committed
feat(classification): allow parallelisation to reduce time (#8368)
1 parent bf5f6f5 commit 9e59bf9

File tree

9 files changed, +232 -57 lines changed

metadata-ingestion/docs/dev_guides/classification.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -10,6 +10,7 @@ Note that a `.` is used to denote nested fields in the YAML recipe.
 | ------------------------- | -------- | --------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------- |
 | enabled | | boolean | Whether classification should be used to auto-detect glossary terms | False |
 | sample_size | | int | Number of sample values used for classification. | 100 |
+| max_workers | | int | Number of worker threads to use for classification. Set to 1 to disable. | Number of cpu cores or 4 |
 | info_type_to_term | | Dict[str,string] | Optional mapping to provide glossary term identifier for info type. | By default, info type is used as glossary term identifier. |
 | classifiers | | Array of object | Classifiers to use to auto-detect glossary terms. If more than one classifier, infotype predictions from the classifier defined later in sequence take precedence. | [{'type': 'datahub', 'config': None}] |
 | table_pattern | | AllowDenyPattern (see below for fields) | Regex patterns to filter tables for classification. This is used in combination with other patterns in parent config. Specify regex to match the entire table name in `database.schema.table` format. e.g. to match all tables starting with customer in Customer database and public schema, use the regex 'Customer.public.customer.*' | {'allow': ['.*'], 'deny': [], 'ignoreCase': True} |
```
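
For illustration, a minimal sketch of setting these options from Python rather than from a recipe; it assumes `ClassificationConfig` is importable from `datahub.ingestion.glossary.classifier` (the module changed below), and the values are examples only:

```python
from datahub.ingestion.glossary.classifier import ClassificationConfig

# Example values only. max_workers=1 disables parallelisation; the default
# is the machine's CPU core count, or 4 when it cannot be detected.
config = ClassificationConfig(
    enabled=True,
    sample_size=100,
    max_workers=4,
)
print(config.max_workers)
```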

metadata-ingestion/src/datahub/ingestion/glossary/classification_mixin.py

Lines changed: 67 additions & 24 deletions
```diff
@@ -1,6 +1,8 @@
+import concurrent.futures
 import logging
 from dataclasses import dataclass, field
-from typing import Dict, List
+from math import ceil
+from typing import Dict, Iterable, List, Optional
 
 from datahub_classify.helper_classes import ColumnInfo, Metadata
 from pydantic import Field
@@ -108,15 +110,23 @@ def classify_schema_fields(
             return None
 
         logger.debug(f"Classifying Table {dataset_name}")
+
         self.report.num_tables_classification_attempted += 1
         field_terms: Dict[str, str] = {}
         with PerfTimer() as timer:
             try:
                 for classifier in self.classifiers:
-                    column_info_with_proposals = classifier.classify(column_infos)
-                    self.extract_field_wise_terms(
-                        field_terms, column_info_with_proposals
-                    )
+                    column_infos_with_proposals: Iterable[ColumnInfo]
+                    if self.config.classification.max_workers > 1:
+                        column_infos_with_proposals = self.async_classify(
+                            classifier, column_infos
+                        )
+                    else:
+                        column_infos_with_proposals = classifier.classify(column_infos)
+
+                    for column_info_proposal in column_infos_with_proposals:
+                        self.update_field_terms(field_terms, column_info_proposal)
+
             except Exception:
                 self.report.num_tables_classification_failed += 1
                 raise
@@ -130,6 +140,44 @@ def classify_schema_fields(
         self.report.num_tables_classified += 1
         self.populate_terms_in_schema_metadata(schema_metadata, field_terms)
 
+    def update_field_terms(
+        self, field_terms: Dict[str, str], col_info: ColumnInfo
+    ) -> None:
+        term = self.get_terms_for_column(col_info)
+        if term:
+            field_terms[col_info.metadata.name] = term
+
+    def async_classify(
+        self, classifier: Classifier, columns: List[ColumnInfo]
+    ) -> Iterable[ColumnInfo]:
+        num_columns = len(columns)
+        BATCH_SIZE = 5  # Number of columns passed to classify api at a time
+
+        logger.debug(
+            f"Will Classify {num_columns} column(s) with {self.config.classification.max_workers} worker(s) with batch size {BATCH_SIZE}."
+        )
+
+        with concurrent.futures.ProcessPoolExecutor(
+            max_workers=self.config.classification.max_workers,
+        ) as executor:
+            column_info_proposal_futures = [
+                executor.submit(
+                    classifier.classify,
+                    columns[
+                        (i * BATCH_SIZE) : min(i * BATCH_SIZE + BATCH_SIZE, num_columns)
+                    ],
+                )
+                for i in range(ceil(num_columns / BATCH_SIZE))
+            ]
+
+            return [
+                column_with_proposal
+                for proposal_future in concurrent.futures.as_completed(
+                    column_info_proposal_futures
+                )
+                for column_with_proposal in proposal_future.result()
+            ]
+
     def populate_terms_in_schema_metadata(
         self,
         schema_metadata: SchemaMetadata,
@@ -154,25 +202,20 @@ def populate_terms_in_schema_metadata(
             ),
         )
 
-    def extract_field_wise_terms(
-        self,
-        field_terms: Dict[str, str],
-        column_info_with_proposals: List[ColumnInfo],
-    ) -> None:
-        for col_info in column_info_with_proposals:
-            if not col_info.infotype_proposals:
-                continue
-            infotype_proposal = max(
-                col_info.infotype_proposals, key=lambda p: p.confidence_level
-            )
-            self.report.info_types_detected.setdefault(
-                infotype_proposal.infotype, LossyList()
-            ).append(f"{col_info.metadata.dataset_name}.{col_info.metadata.name}")
-            field_terms[
-                col_info.metadata.name
-            ] = self.config.classification.info_type_to_term.get(
-                infotype_proposal.infotype, infotype_proposal.infotype
-            )
+    def get_terms_for_column(self, col_info: ColumnInfo) -> Optional[str]:
+        if not col_info.infotype_proposals:
+            return None
+        infotype_proposal = max(
+            col_info.infotype_proposals, key=lambda p: p.confidence_level
+        )
+        self.report.info_types_detected.setdefault(
+            infotype_proposal.infotype, LossyList()
+        ).append(f"{col_info.metadata.dataset_name}.{col_info.metadata.name}")
+        term = self.config.classification.info_type_to_term.get(
+            infotype_proposal.infotype, infotype_proposal.infotype
+        )
+
+        return term
 
     def get_columns_to_classify(
         self,
```
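
For readers skimming the diff, the following is a minimal, self-contained sketch of the batched fan-out pattern that `async_classify` uses; `classify_batch` is a stand-in for `classifier.classify`, and all names here are illustrative rather than taken from the commit:

```python
import concurrent.futures
from math import ceil
from typing import List

BATCH_SIZE = 5  # mirrors the batch size hard-coded in async_classify

def classify_batch(columns: List[str]) -> List[str]:
    # Stand-in for classifier.classify; must be a top-level function so the
    # process pool can pickle it.
    return [f"{col}:classified" for col in columns]

def parallel_classify(columns: List[str], max_workers: int = 4) -> List[str]:
    num_batches = ceil(len(columns) / BATCH_SIZE)
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(classify_batch, columns[i * BATCH_SIZE : (i + 1) * BATCH_SIZE])
            for i in range(num_batches)
        ]
        # as_completed yields futures in finish order, so results are unordered;
        # that is acceptable here because each ColumnInfo carries its own name.
        return [
            result
            for future in concurrent.futures.as_completed(futures)
            for result in future.result()
        ]

if __name__ == "__main__":  # required for ProcessPoolExecutor on spawn platforms
    print(parallel_classify([f"col_{i}" for i in range(12)]))
```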

metadata-ingestion/src/datahub/ingestion/glossary/classifier.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -1,3 +1,4 @@
+import os
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
@@ -36,6 +37,11 @@ class ClassificationConfig(ConfigModel):
         default=100, description="Number of sample values used for classification."
     )
 
+    max_workers: int = Field(
+        default=(os.cpu_count() or 4),
+        description="Number of worker threads to use for classification. Set to 1 to disable.",
+    )
+
     table_pattern: AllowDenyPattern = Field(
         default=AllowDenyPattern.allow_all(),
         description="Regex patterns to filter tables for classification. This is used in combination with other patterns in parent config. Specify regex to match the entire table name in `database.schema.table` format. e.g. to match all tables starting with customer in Customer database and public schema, use the regex 'Customer.public.customer.*'",
```
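
The `os.cpu_count() or 4` default exists because `os.cpu_count()` is documented to return `None` when the core count cannot be determined; the fallback keeps the default a concrete `int`. (Note that although the description says worker threads, the executor in `classification_mixin.py` is a process pool.) A tiny sketch of the fallback:

```python
import os

# os.cpu_count() may return None on platforms where the core count is
# undeterminable, so `or 4` guarantees a concrete integer default.
default_workers = os.cpu_count() or 4
print(f"default max_workers: {default_workers}")
```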

metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -173,4 +173,5 @@ def classify(self, columns: List[ColumnInfo]) -> List[ColumnInfo]:
             infotypes=self.config.info_types,
             minimum_values_threshold=self.config.minimum_values_threshold,
         )
+
         return columns
```

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -872,8 +872,8 @@ def _process_table(
         self.gen_schema_metadata(table, schema_name, db_name)
 
     def fetch_sample_data_for_classification(
-        self, table, schema_name, db_name, dataset_name
-    ):
+        self, table: SnowflakeTable, schema_name: str, db_name: str, dataset_name: str
+    ) -> None:
         if (
             table.columns
             and self.config.classification.enabled
@@ -1225,7 +1225,12 @@ def build_foreign_keys(self, table, dataset_urn, foreign_keys):
         )
         return foreign_keys
 
-    def classify_snowflake_table(self, table, dataset_name, schema_metadata):
+    def classify_snowflake_table(
+        self,
+        table: Union[SnowflakeTable, SnowflakeView],
+        dataset_name: str,
+        schema_metadata: SchemaMetadata,
+    ) -> None:
         if (
             isinstance(table, SnowflakeTable)
             and self.config.classification.enabled
@@ -1255,6 +1260,9 @@ def classify_snowflake_table(self, table, dataset_name, schema_metadata):
                 "Failed to classify table columns",
                 dataset_name,
             )
+        finally:
+            # Cleaning up sample_data fetched for classification
+            table.sample_data = None
 
     def get_report(self) -> SourceReport:
         return self.report
@@ -1470,7 +1478,7 @@ def get_sample_values_for_table(self, table_name, schema_name, db_name):
             df = pd.DataFrame(dat, columns=[col.name for col in cur.description])
         time_taken = timer.elapsed_seconds()
         logger.debug(
-            f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name}; took {time_taken:.3f} seconds"
+            f"Finished collecting sample values for table {db_name}.{schema_name}.{table_name};{df.shape[0]} rows; took {time_taken:.3f} seconds"
         )
 
         return df
```
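
The new `finally` block releases the sampled rows whether or not classification succeeds, so large sample DataFrames are not kept alive for the rest of the ingestion run. A minimal sketch of the same pattern, using a hypothetical `Table` as a stand-in for `SnowflakeTable`:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Table:  # hypothetical stand-in for SnowflakeTable
    name: str
    sample_data: Optional[Any] = None

def classify_table(table: Table) -> None:
    try:
        if table.sample_data is None:
            raise ValueError("no sample data fetched")
        print(f"classifying {table.name} from {len(table.sample_data)} sampled rows")
    finally:
        # Release the sample whether or not classification raised.
        table.sample_data = None

t = Table("TABLE_1", sample_data=[{"col_1": 1}])
classify_table(t)
assert t.sample_data is None
```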

metadata-ingestion/tests/integration/snowflake/common.py

Lines changed: 28 additions & 22 deletions
```diff
@@ -14,7 +14,13 @@
 FROZEN_TIME = "2022-06-07 17:00:00"
 
 
-def default_query_results(query):  # noqa: C901
+def default_query_results(  # noqa: C901
+    query,
+    num_tables=NUM_TABLES,
+    num_views=NUM_VIEWS,
+    num_cols=NUM_COLS,
+    num_ops=NUM_OPS,
+):
     if query == SnowflakeQuery.current_account():
         return [{"CURRENT_ACCOUNT()": "ABC12345"}]
     if query == SnowflakeQuery.current_region():
@@ -79,7 +85,7 @@ def default_query_results(query):  # noqa: C901
                 "COMMENT": "Comment for Table",
                 "CLUSTERING_KEY": None,
             }
-            for tbl_idx in range(1, NUM_TABLES + 1)
+            for tbl_idx in range(1, num_tables + 1)
         ]
     elif query == SnowflakeQuery.show_views_for_schema("TEST_SCHEMA", "TEST_DB"):
         return [
@@ -90,7 +96,7 @@ def default_query_results(query):  # noqa: C901
                 "comment": "Comment for View",
                 "text": None,
             }
-            for view_idx in range(1, NUM_VIEWS + 1)
+            for view_idx in range(1, num_views + 1)
         ]
     elif query == SnowflakeQuery.columns_for_schema("TEST_SCHEMA", "TEST_DB"):
         raise Exception("Information schema query returned too much data")
@@ -99,13 +105,13 @@ def default_query_results(query):  # noqa: C901
             SnowflakeQuery.columns_for_table(
                 "TABLE_{}".format(tbl_idx), "TEST_SCHEMA", "TEST_DB"
             )
-            for tbl_idx in range(1, NUM_TABLES + 1)
+            for tbl_idx in range(1, num_tables + 1)
         ],
         *[
             SnowflakeQuery.columns_for_table(
                 "VIEW_{}".format(view_idx), "TEST_SCHEMA", "TEST_DB"
             )
-            for view_idx in range(1, NUM_VIEWS + 1)
+            for view_idx in range(1, num_views + 1)
         ],
     ]:
         return [
@@ -122,7 +128,7 @@ def default_query_results(query):  # noqa: C901
                 "NUMERIC_PRECISION": None if col_idx > 1 else 38,
                 "NUMERIC_SCALE": None if col_idx > 1 else 0,
             }
-            for col_idx in range(1, NUM_COLS + 1)
+            for col_idx in range(1, num_cols + 1)
         ]
     elif query in (
         SnowflakeQuery.use_database("TEST_DB"),
@@ -158,7 +164,7 @@ def default_query_results(query):  # noqa: C901
             {
                 "columns": [
                     {"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-                    for col_idx in range(1, NUM_COLS + 1)
+                    for col_idx in range(1, num_cols + 1)
                 ],
                 "objectDomain": "Table",
                 "objectId": 0,
@@ -167,7 +173,7 @@ def default_query_results(query):  # noqa: C901
             {
                 "columns": [
                     {"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-                    for col_idx in range(1, NUM_COLS + 1)
+                    for col_idx in range(1, num_cols + 1)
                 ],
                 "objectDomain": "Table",
                 "objectId": 0,
@@ -176,7 +182,7 @@ def default_query_results(query):  # noqa: C901
             {
                 "columns": [
                     {"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-                    for col_idx in range(1, NUM_COLS + 1)
+                    for col_idx in range(1, num_cols + 1)
                 ],
                 "objectDomain": "Table",
                 "objectId": 0,
@@ -189,7 +195,7 @@ def default_query_results(query):  # noqa: C901
             {
                 "columns": [
                     {"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-                    for col_idx in range(1, NUM_COLS + 1)
+                    for col_idx in range(1, num_cols + 1)
                 ],
                 "objectDomain": "Table",
                 "objectId": 0,
@@ -198,7 +204,7 @@ def default_query_results(query):  # noqa: C901
             {
                 "columns": [
                     {"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-                    for col_idx in range(1, NUM_COLS + 1)
+                    for col_idx in range(1, num_cols + 1)
                 ],
                 "objectDomain": "Table",
                 "objectId": 0,
@@ -207,7 +213,7 @@ def default_query_results(query):  # noqa: C901
             {
                 "columns": [
                     {"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-                    for col_idx in range(1, NUM_COLS + 1)
+                    for col_idx in range(1, num_cols + 1)
                 ],
                 "objectDomain": "Table",
                 "objectId": 0,
@@ -231,7 +237,7 @@ def default_query_results(query):  # noqa: C901
                         }
                     ],
                 }
-                for col_idx in range(1, NUM_COLS + 1)
+                for col_idx in range(1, num_cols + 1)
             ],
             "objectDomain": "Table",
             "objectId": 0,
@@ -246,7 +252,7 @@ def default_query_results(query):  # noqa: C901
             "EMAIL": "[email protected]",
             "ROLE_NAME": "ACCOUNTADMIN",
         }
-        for op_idx in range(1, NUM_OPS + 1)
+        for op_idx in range(1, num_ops + 1)
     ]
     elif (
         query
@@ -276,7 +282,7 @@ def default_query_results(query):  # noqa: C901
                 "UPSTREAM_TABLE_COLUMNS": json.dumps(
                     [
                         {"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-                        for col_idx in range(1, NUM_COLS + 1)
+                        for col_idx in range(1, num_cols + 1)
                     ]
                 ),
                 "DOWNSTREAM_TABLE_COLUMNS": json.dumps(
@@ -293,11 +299,11 @@ def default_query_results(query):  # noqa: C901
                             }
                         ],
                     }
-                    for col_idx in range(1, NUM_COLS + 1)
+                    for col_idx in range(1, num_cols + 1)
                 ]
             ),
         }
-        for op_idx in range(1, NUM_OPS + 1)
+        for op_idx in range(1, num_ops + 1)
     ] + [
         {
             "DOWNSTREAM_TABLE_NAME": "TEST_DB.TEST_SCHEMA.TABLE_1",
@@ -371,7 +377,7 @@ def default_query_results(query):  # noqa: C901
                 ]
             ],
         }
-        for col_idx in range(1, NUM_COLS + 1)
+        for col_idx in range(1, num_cols + 1)
     ]
     + (  # This additional upstream is only for TABLE_1
         [
@@ -393,7 +399,7 @@ def default_query_results(query):  # noqa: C901
                 )
             ),
         }
-        for op_idx in range(1, NUM_OPS + 1)
+        for op_idx in range(1, num_ops + 1)
     ]
     elif query in (
         snowflake_query.SnowflakeQuery.table_to_table_lineage_history_v2(
@@ -426,7 +432,7 @@ def default_query_results(query):  # noqa: C901
                 )
             ),
         }
-        for op_idx in range(1, NUM_OPS + 1)
+        for op_idx in range(1, num_ops + 1)
     ]
    elif query == snowflake_query.SnowflakeQuery.external_table_lineage_history(
        1654499820000,
@@ -479,7 +485,7 @@ def default_query_results(query):  # noqa: C901
             "VIEW_COLUMNS": json.dumps(
                 [
                     {"columnId": 0, "columnName": "COL_{}".format(col_idx)}
-                    for col_idx in range(1, NUM_COLS + 1)
+                    for col_idx in range(1, num_cols + 1)
                 ]
             ),
             "DOWNSTREAM_TABLE_DOMAIN": "TABLE",
@@ -497,7 +503,7 @@ def default_query_results(query):  # noqa: C901
                         }
                     ],
                 }
-                for col_idx in range(1, NUM_COLS + 1)
+                for col_idx in range(1, num_cols + 1)
             ]
         ),
     }
```
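
A sketch of how a test might use the new parameters, assuming this module's `default_query_results` is importable as below and the Snowflake cursor is mocked the way these integration tests usually do; the binding is illustrative:

```python
from functools import partial
from unittest import mock

from tests.integration.snowflake.common import default_query_results  # assumed path

# Bind a larger column count; the mocked cursor then answers every query
# from the parametrised fixture.
query_results_wide = partial(default_query_results, num_cols=30)

sf_cursor = mock.MagicMock()
sf_cursor.execute.side_effect = query_results_wide
```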
