
Commit 680d16f

feat(ingestion/sql-queries): add performance optimizations, S3 support, temp table patterns (#14757)
Co-authored-by: Sergio Gómez Villamor <[email protected]>
1 parent f05f3e4 commit 680d16f

File tree: 15 files changed (+2458 −266 lines)

metadata-ingestion/src/datahub/ingestion/graph/client.py

Lines changed: 5 additions & 1 deletion
@@ -102,6 +102,7 @@
 from datahub.sql_parsing.schema_resolver import (
     GraphQLSchemaMetadata,
     SchemaResolver,
+    SchemaResolverReport,
 )
 from datahub.sql_parsing.sqlglot_lineage import SqlParsingResult
 
@@ -1543,6 +1544,7 @@ def _make_schema_resolver(
         platform_instance: Optional[str],
         env: str,
         include_graph: bool = True,
+        report: Optional["SchemaResolverReport"] = None,
     ) -> "SchemaResolver":
         from datahub.sql_parsing.schema_resolver import SchemaResolver
 
@@ -1551,6 +1553,7 @@ def _make_schema_resolver(
             platform_instance=platform_instance,
             env=env,
             graph=self if include_graph else None,
+            report=report,
         )
 
     def initialize_schema_resolver_from_datahub(
@@ -1559,10 +1562,11 @@ def initialize_schema_resolver_from_datahub(
         platform_instance: Optional[str],
         env: str,
         batch_size: int = 100,
+        report: Optional["SchemaResolverReport"] = None,
     ) -> "SchemaResolver":
         logger.info("Initializing schema resolver")
         schema_resolver = self._make_schema_resolver(
-            platform, platform_instance, env, include_graph=False
+            platform, platform_instance, env, include_graph=False, report=report
         )
 
         logger.info(f"Fetching schemas for platform {platform}, env {env}")

metadata-ingestion/src/datahub/ingestion/run/pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -558,6 +558,7 @@ def run(self) -> None:
 
             self.process_commits()
             self.final_status = PipelineStatus.COMPLETED
+
         except (SystemExit, KeyboardInterrupt):
             self.final_status = PipelineStatus.CANCELLED
             logger.error("Caught error", exc_info=True)

metadata-ingestion/src/datahub/ingestion/source/sql_queries.py

Lines changed: 164 additions & 15 deletions
@@ -1,11 +1,13 @@
 import json
 import logging
 import os
-from dataclasses import dataclass
+import re
+from dataclasses import dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import ClassVar, Iterable, List, Optional, Union
+from typing import ClassVar, Iterable, List, Optional, Union, cast
 
+import smart_open
 from pydantic import BaseModel, Field, validator
 
 from datahub.configuration.common import HiddenFromDocs
@@ -36,12 +38,13 @@
     SourceCapability,
     SourceReport,
 )
-from datahub.ingestion.api.source_helpers import auto_workunit_reporter
+from datahub.ingestion.api.source_helpers import auto_workunit, auto_workunit_reporter
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.graph.client import DataHubGraph
+from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
 from datahub.metadata.urns import CorpUserUrn, DatasetUrn
-from datahub.sql_parsing.schema_resolver import SchemaResolver
+from datahub.sql_parsing.schema_resolver import SchemaResolver, SchemaResolverReport
 from datahub.sql_parsing.sql_parsing_aggregator import (
     KnownQueryLineageInfo,
     ObservedQuery,
@@ -82,15 +85,38 @@ class SqlQueriesSourceConfig(
         None,
         description="The SQL dialect to use when parsing queries. Overrides automatic dialect detection.",
     )
+    temp_table_patterns: List[str] = Field(
+        description="Regex patterns for temporary tables to filter in lineage ingestion. "
+        "Specify regex to match the entire table name. This is useful for platforms like Athena "
+        "that don't have native temp tables but use naming patterns for fake temp tables.",
+        default=[],
+    )
+
+    enable_lazy_schema_loading: bool = Field(
+        default=True,
+        description="Enable lazy schema loading for better performance. When enabled, schemas are fetched on-demand "
+        "instead of bulk loading all schemas upfront, reducing startup time and memory usage.",
+    )
+
+    # AWS/S3 configuration
+    aws_config: Optional[AwsConnectionConfig] = Field(
+        default=None,
+        description="AWS configuration for S3 access. Required when query_file is an S3 URI (s3://).",
+    )
 
 
 @dataclass
 class SqlQueriesSourceReport(SourceReport):
     num_entries_processed: int = 0
     num_entries_failed: int = 0
     num_queries_aggregator_failures: int = 0
+    num_queries_processed_sequential: int = 0
+    num_temp_tables_detected: int = 0
+    temp_table_patterns_used: List[str] = field(default_factory=list)
+    peak_memory_usage_mb: float = 0.0
 
     sql_aggregator: Optional[SqlAggregatorReport] = None
+    schema_resolver_report: Optional[SchemaResolverReport] = None
 
 
 @platform_name("SQL Queries", id="sql-queries")
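
As a rough illustration of how these new options fit together, a recipe-style source config might look like the sketch below. This is a hedged example: only temp_table_patterns, enable_lazy_schema_loading, and aws_config are introduced by this commit, platform and query_file are pre-existing fields of the source, and all values and the AwsConnectionConfig contents are placeholders.

# Hypothetical SQL Queries source configuration, expressed as a plain dict.
sql_queries_source_config = {
    "platform": "athena",                              # existing field
    "query_file": "s3://example-bucket/queries.json",  # local path or s3:// URI
    "enable_lazy_schema_loading": True,                # new: fetch schemas on demand
    "temp_table_patterns": [r".*\.tmp_.*"],            # new: regexes marking fake temp tables
    "aws_config": {                                    # new: required for s3:// query files
        # AwsConnectionConfig options (region, credentials, profile, ...) go here.
    },
}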
@@ -115,6 +141,18 @@ class SqlQueriesSource(Source):
     - upstream_tables (optional): string[] - Fallback list of tables the query reads from,
       used if the query can't be parsed.
 
+    **Lazy Schema Loading**:
+    - Fetches schemas on-demand during query parsing instead of bulk loading all schemas upfront
+    - Caches fetched schemas for future lookups to avoid repeated network requests
+    - Reduces initial startup time and memory usage significantly
+    - Automatically handles large platforms efficiently without memory issues
+
+    **Query Processing**:
+    - Loads the entire query file into memory at once
+    - Processes all queries sequentially before generating metadata work units
+    - Preserves temp table mappings and lineage relationships to ensure consistent lineage tracking
+    - Query deduplication is handled automatically by the SQL parsing aggregator
+
     ### Incremental Lineage
     When `incremental_lineage` is enabled, this source will emit lineage as patches rather than full overwrites.
     This allows you to add lineage edges without removing existing ones, which is useful for:
@@ -124,6 +162,12 @@ class SqlQueriesSource(Source):
 
     Note: Incremental lineage only applies to UpstreamLineage aspects. Other aspects like queries and usage
     statistics will still be emitted normally.
+
+    ### Temporary Table Support
+    For platforms like Athena that don't have native temporary tables, you can use the `temp_table_patterns`
+    configuration to specify regex patterns that identify fake temporary tables. This allows the source to
+    process these tables like other sources that support native temp tables, enabling proper lineage tracking
+    across temporary table operations.
     """
 
     schema_resolver: Optional[SchemaResolver]
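
Since the docstring above only mentions the query file format in passing, here is a small sketch of producing such a newline-delimited JSON file. Only the required query field and the optional upstream_tables fallback come from the documentation above; the table names and file name are illustrative, and any other optional fields are omitted.

import json

# Each line of the query file is one JSON object.
entries = [
    {"query": "CREATE TABLE tmp_orders AS SELECT * FROM raw.orders"},
    {
        "query": "INSERT INTO analytics.orders SELECT * FROM tmp_orders",
        # Optional fallback, used only if the query cannot be parsed.
        "upstream_tables": ["tmp_orders"],
    },
]

with open("queries.json", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")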
@@ -141,13 +185,19 @@ def __init__(self, ctx: PipelineContext, config: SqlQueriesSourceConfig):
         self.report = SqlQueriesSourceReport()
 
         if self.config.use_schema_resolver:
-            # TODO: `initialize_schema_resolver_from_datahub` does a bulk initialization by fetching all schemas
-            # for the given platform, platform instance, and env. Instead this should be configurable:
-            # bulk initialization vs lazy on-demand schema fetching.
-            self.schema_resolver = self.graph.initialize_schema_resolver_from_datahub(
+            # Create schema resolver report for tracking
+            self.report.schema_resolver_report = SchemaResolverReport()
+
+            # Use lazy loading - schemas will be fetched on-demand and cached
+            logger.info(
+                "Using lazy schema loading - schemas will be fetched on-demand and cached"
+            )
+            self.schema_resolver = SchemaResolver(
                 platform=self.config.platform,
                 platform_instance=self.config.platform_instance,
                 env=self.config.env,
+                graph=self.graph,
+                report=self.report.schema_resolver_report,
             )
         else:
             self.schema_resolver = None
@@ -156,7 +206,9 @@ def __init__(self, ctx: PipelineContext, config: SqlQueriesSourceConfig):
             platform=self.config.platform,
             platform_instance=self.config.platform_instance,
             env=self.config.env,
-            schema_resolver=self.schema_resolver,
+            schema_resolver=cast(SchemaResolver, self.schema_resolver)
+            if self.schema_resolver
+            else None,
             eager_graph_load=False,
             generate_lineage=True,  # TODO: make this configurable
             generate_queries=True,  # TODO: make this configurable
@@ -165,7 +217,9 @@ def __init__(self, ctx: PipelineContext, config: SqlQueriesSourceConfig):
             generate_usage_statistics=True,
             generate_operations=True,  # TODO: make this configurable
             usage_config=self.config.usage,
-            is_temp_table=None,
+            is_temp_table=self.is_temp_table
+            if self.config.temp_table_patterns
+            else None,
             is_allowed_table=None,
             format_queries=False,
        )
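
To make the lazy-loading change concrete, a minimal standalone sketch contrasting the previous bulk initialization with the on-demand resolver the source now builds directly. The SchemaResolver keyword arguments mirror the diff above; the connection setup (DatahubClientConfig, the server URL, and the "athena" platform) is an assumption for illustration only.

from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig
from datahub.sql_parsing.schema_resolver import SchemaResolver, SchemaResolverReport

# Assumed connection details; replace with your own DataHub endpoint.
graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
report = SchemaResolverReport()

# Previous behavior: bulk-fetch every schema for the platform up front.
# bulk_resolver = graph.initialize_schema_resolver_from_datahub(
#     "athena", platform_instance=None, env="PROD", report=report
# )

# New behavior: pass the graph so schemas are fetched lazily, the first time
# a query references them, and then cached for later lookups.
lazy_resolver = SchemaResolver(
    platform="athena",
    platform_instance=None,
    env="PROD",
    graph=graph,
    report=report,
)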
@@ -193,20 +247,73 @@ def get_workunits_internal(
     ) -> Iterable[Union[MetadataWorkUnit, MetadataChangeProposalWrapper]]:
         logger.info(f"Parsing queries from {os.path.basename(self.config.query_file)}")
 
+        logger.info("Processing all queries in batch mode")
+        yield from self._process_queries_batch()
+
+    def _process_queries_batch(
+        self,
+    ) -> Iterable[Union[MetadataWorkUnit, MetadataChangeProposalWrapper]]:
+        """Process all queries in memory (original behavior)."""
         with self.report.new_stage("Collecting queries from file"):
             queries = list(self._parse_query_file())
             logger.info(f"Collected {len(queries)} queries for processing")
 
         with self.report.new_stage("Processing queries through SQL parsing aggregator"):
-            for query_entry in queries:
-                self._add_query_to_aggregator(query_entry)
+            logger.info("Using sequential processing")
+            self._process_queries_sequential(queries)
 
         with self.report.new_stage("Generating metadata work units"):
             logger.info("Generating workunits from SQL parsing aggregator")
-            yield from self.aggregator.gen_metadata()
+            yield from auto_workunit(self.aggregator.gen_metadata())
 
-    def _parse_query_file(self) -> Iterable["QueryEntry"]:
-        """Parse the query file and yield QueryEntry objects."""
+    def _is_s3_uri(self, path: str) -> bool:
+        """Check if the path is an S3 URI."""
+        return path.startswith("s3://")
+
+    def _parse_s3_query_file(self) -> Iterable["QueryEntry"]:
+        """Parse query file from S3 using smart_open."""
+        if not self.config.aws_config:
+            raise ValueError("AWS configuration required for S3 file access")
+
+        logger.info(f"Reading query file from S3: {self.config.query_file}")
+
+        try:
+            # Use smart_open for efficient S3 streaming, similar to S3FileSystem
+            s3_client = self.config.aws_config.get_s3_client()
+
+            with smart_open.open(
+                self.config.query_file, mode="r", transport_params={"client": s3_client}
+            ) as file_stream:
+                for line in file_stream:
+                    if line.strip():
+                        try:
+                            query_dict = json.loads(line, strict=False)
+                            entry = QueryEntry.create(query_dict, config=self.config)
+                            self.report.num_entries_processed += 1
+                            if self.report.num_entries_processed % 1000 == 0:
+                                logger.info(
+                                    f"Processed {self.report.num_entries_processed} query entries from S3"
+                                )
+                            yield entry
+                        except Exception as e:
+                            self.report.num_entries_failed += 1
+                            self.report.warning(
+                                title="Error processing query from S3",
+                                message="Query skipped due to parsing error",
+                                context=line.strip(),
+                                exc=e,
+                            )
+        except Exception as e:
+            self.report.warning(
+                title="Error reading S3 file",
+                message="Failed to read S3 file",
+                context=self.config.query_file,
+                exc=e,
+            )
+            raise
+
+    def _parse_local_query_file(self) -> Iterable["QueryEntry"]:
+        """Parse local query file (existing logic)."""
         with open(self.config.query_file) as f:
             for line in f:
                 try:
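
Outside the source, the same smart_open streaming pattern used by _parse_s3_query_file can be sketched as follows. The bucket, key, and a plain boto3 client (standing in for AwsConnectionConfig.get_s3_client()) are assumptions; credentials are expected to come from the environment.

import json

import boto3
import smart_open

# Stand-in for the client returned by AwsConnectionConfig.get_s3_client().
s3_client = boto3.client("s3")

# Stream the newline-delimited JSON file line by line instead of downloading it whole.
with smart_open.open(
    "s3://example-bucket/queries.json",
    mode="r",
    transport_params={"client": s3_client},
) as file_stream:
    for line in file_stream:
        if line.strip():
            query_dict = json.loads(line, strict=False)
            print(query_dict.get("query"))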
@@ -227,6 +334,30 @@ def _parse_query_file(self) -> Iterable["QueryEntry"]:
                         exc=e,
                     )
 
+    def _parse_query_file(self) -> Iterable["QueryEntry"]:
+        """Parse the query file and yield QueryEntry objects."""
+        if self._is_s3_uri(self.config.query_file):
+            yield from self._parse_s3_query_file()
+        else:
+            yield from self._parse_local_query_file()
+
+    def _process_queries_sequential(self, queries: List["QueryEntry"]) -> None:
+        """Process queries sequentially."""
+        total_queries = len(queries)
+        logger.info(f"Processing {total_queries} queries sequentially")
+
+        # Process each query sequentially
+        for i, query_entry in enumerate(queries):
+            self._add_query_to_aggregator(query_entry)
+            self.report.num_queries_processed_sequential += 1
+
+            # Simple progress reporting every 1000 queries
+            if (i + 1) % 1000 == 0:
+                progress_pct = ((i + 1) / total_queries) * 100
+                logger.info(
+                    f"Processed {i + 1}/{total_queries} queries ({progress_pct:.1f}%)"
+                )
+
     def _add_query_to_aggregator(self, query_entry: "QueryEntry") -> None:
         """Add a query to the SQL parsing aggregator."""
         try:
@@ -285,6 +416,24 @@ def _add_query_to_aggregator(self, query_entry: "QueryEntry") -> None:
                 exc=e,
             )
 
+    def is_temp_table(self, name: str) -> bool:
+        """Check if a table name matches any of the configured temp table patterns."""
+        if not self.config.temp_table_patterns:
+            return False
+
+        try:
+            for pattern in self.config.temp_table_patterns:
+                if re.match(pattern, name, flags=re.IGNORECASE):
+                    logger.debug(
+                        f"Table '{name}' matched temp table pattern: {pattern}"
+                    )
+                    self.report.num_temp_tables_detected += 1
+                    return True
+        except re.error as e:
+            logger.warning(f"Invalid regex pattern '{pattern}': {e}")
+
+        return False
+
 
 class QueryEntry(BaseModel):
     query: str
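
To illustrate the matching behavior of is_temp_table above (re.match anchors each pattern at the start of the table name and re.IGNORECASE makes it case-insensitive), here is a small standalone check; the pattern and table names are a made-up Athena-style naming convention, not part of the commit.

import re

# Hypothetical convention: any table whose name (after the database prefix) starts with "tmp_".
temp_table_patterns = [r".*\.tmp_.*"]

def matches_temp_pattern(name: str) -> bool:
    # Mirrors the source's per-pattern check.
    return any(re.match(p, name, flags=re.IGNORECASE) for p in temp_table_patterns)

print(matches_temp_pattern("analytics_db.TMP_daily_orders"))  # True
print(matches_temp_pattern("analytics_db.daily_orders"))      # False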
