 import json
 import logging
 import os
-from dataclasses import dataclass
+import re
+from dataclasses import dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import ClassVar, Iterable, List, Optional, Union
+from typing import ClassVar, Iterable, List, Optional, Union, cast

+import smart_open
 from pydantic import BaseModel, Field, validator

 from datahub.configuration.common import HiddenFromDocs
     SourceCapability,
     SourceReport,
 )
-from datahub.ingestion.api.source_helpers import auto_workunit_reporter
+from datahub.ingestion.api.source_helpers import auto_workunit, auto_workunit_reporter
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.graph.client import DataHubGraph
+from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
 from datahub.metadata.urns import CorpUserUrn, DatasetUrn
-from datahub.sql_parsing.schema_resolver import SchemaResolver
+from datahub.sql_parsing.schema_resolver import SchemaResolver, SchemaResolverReport
 from datahub.sql_parsing.sql_parsing_aggregator import (
     KnownQueryLineageInfo,
     ObservedQuery,
@@ -82,15 +85,38 @@ class SqlQueriesSourceConfig(
         None,
         description="The SQL dialect to use when parsing queries. Overrides automatic dialect detection.",
     )
+    temp_table_patterns: List[str] = Field(
+        description="Regex patterns for temporary tables to filter in lineage ingestion. "
+        "Specify regex to match the entire table name. This is useful for platforms like Athena "
+        "that don't have native temp tables but use naming patterns for fake temp tables.",
+        default=[],
+    )
+
+    enable_lazy_schema_loading: bool = Field(
+        default=True,
+        description="Enable lazy schema loading for better performance. When enabled, schemas are fetched on-demand "
+        "instead of bulk loading all schemas upfront, reducing startup time and memory usage.",
+    )
+
+    # AWS/S3 configuration
+    aws_config: Optional[AwsConnectionConfig] = Field(
+        default=None,
+        description="AWS configuration for S3 access. Required when query_file is an S3 URI (s3://).",
+    )
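+    # Illustrative recipe fragment for the two settings above (values are placeholders,
+    # and `aws_region` is assumed to be an AwsConnectionConfig field):
+    #
+    #   query_file: "s3://my-bucket/queries.jsonl"
+    #   aws_config:
+    #     aws_region: us-east-1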


 @dataclass
 class SqlQueriesSourceReport(SourceReport):
     num_entries_processed: int = 0
     num_entries_failed: int = 0
     num_queries_aggregator_failures: int = 0
+    num_queries_processed_sequential: int = 0
+    num_temp_tables_detected: int = 0
+    temp_table_patterns_used: List[str] = field(default_factory=list)
+    peak_memory_usage_mb: float = 0.0

     sql_aggregator: Optional[SqlAggregatorReport] = None
+    schema_resolver_report: Optional[SchemaResolverReport] = None


 @platform_name("SQL Queries", id="sql-queries")
@@ -115,6 +141,18 @@ class SqlQueriesSource(Source):
     - upstream_tables (optional): string[] - Fallback list of tables the query reads from,
       used if the query can't be parsed.

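+    A minimal example line (illustrative only; the table naming format depends on your platform, and
+    `upstream_tables` is only consulted when the query cannot be parsed):
+
+        {"query": "INSERT INTO db.sales SELECT * FROM db.staging_sales", "upstream_tables": ["db.staging_sales"]}
+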
+    **Lazy Schema Loading**:
+    - Fetches schemas on-demand during query parsing instead of bulk loading all schemas upfront
+    - Caches fetched schemas for future lookups to avoid repeated network requests
+    - Reduces initial startup time and memory usage significantly
+    - Automatically handles large platforms efficiently without memory issues
+
+    **Query Processing**:
+    - Loads the entire query file into memory at once
+    - Processes all queries sequentially before generating metadata work units
+    - Preserves temp table mappings and lineage relationships to ensure consistent lineage tracking
+    - Query deduplication is handled automatically by the SQL parsing aggregator
+
     ### Incremental Lineage
     When `incremental_lineage` is enabled, this source will emit lineage as patches rather than full overwrites.
     This allows you to add lineage edges without removing existing ones, which is useful for:
@@ -124,6 +162,12 @@ class SqlQueriesSource(Source):

     Note: Incremental lineage only applies to UpstreamLineage aspects. Other aspects like queries and usage
     statistics will still be emitted normally.
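+
+    An illustrative recipe snippet (all other settings omitted; `sql-queries` is this source's type):
+
+        source:
+          type: sql-queries
+          config:
+            incremental_lineage: true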
+
+    ### Temporary Table Support
+    For platforms like Athena that don't have native temporary tables, you can use the `temp_table_patterns`
+    configuration to specify regex patterns that identify fake temporary tables. This allows the source to
+    process these tables like other sources that support native temp tables, enabling proper lineage tracking
+    across temporary table operations.
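+
+    For example, with a naming convention where staging tables carry a `tmp_` prefix (the pattern below is
+    illustrative; adapt it to your own convention and remember that it must match the entire table name):
+
+        temp_table_patterns:
+          - '.*\.tmp_.*'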
     """

     schema_resolver: Optional[SchemaResolver]
@@ -141,13 +185,19 @@ def __init__(self, ctx: PipelineContext, config: SqlQueriesSourceConfig):
         self.report = SqlQueriesSourceReport()

         if self.config.use_schema_resolver:
-            # TODO: `initialize_schema_resolver_from_datahub` does a bulk initialization by fetching all schemas
-            # for the given platform, platform instance, and env. Instead this should be configurable:
-            # bulk initialization vs lazy on-demand schema fetching.
-            self.schema_resolver = self.graph.initialize_schema_resolver_from_datahub(
+            # Create schema resolver report for tracking
+            self.report.schema_resolver_report = SchemaResolverReport()
+
+            # Use lazy loading - schemas will be fetched on-demand and cached
+            logger.info(
+                "Using lazy schema loading - schemas will be fetched on-demand and cached"
+            )
+            self.schema_resolver = SchemaResolver(
                 platform=self.config.platform,
                 platform_instance=self.config.platform_instance,
                 env=self.config.env,
+                graph=self.graph,
+                report=self.report.schema_resolver_report,
             )
         else:
             self.schema_resolver = None
@@ -156,7 +206,9 @@ def __init__(self, ctx: PipelineContext, config: SqlQueriesSourceConfig):
             platform=self.config.platform,
             platform_instance=self.config.platform_instance,
             env=self.config.env,
-            schema_resolver=self.schema_resolver,
+            schema_resolver=cast(SchemaResolver, self.schema_resolver)
+            if self.schema_resolver
+            else None,
             eager_graph_load=False,
             generate_lineage=True,  # TODO: make this configurable
             generate_queries=True,  # TODO: make this configurable
@@ -165,7 +217,9 @@ def __init__(self, ctx: PipelineContext, config: SqlQueriesSourceConfig):
             generate_usage_statistics=True,
             generate_operations=True,  # TODO: make this configurable
             usage_config=self.config.usage,
-            is_temp_table=None,
+            is_temp_table=self.is_temp_table
+            if self.config.temp_table_patterns
+            else None,
             is_allowed_table=None,
             format_queries=False,
         )
@@ -193,20 +247,73 @@ def get_workunits_internal(
     ) -> Iterable[Union[MetadataWorkUnit, MetadataChangeProposalWrapper]]:
         logger.info(f"Parsing queries from {os.path.basename(self.config.query_file)}")

+        logger.info("Processing all queries in batch mode")
+        yield from self._process_queries_batch()
+
+    def _process_queries_batch(
+        self,
+    ) -> Iterable[Union[MetadataWorkUnit, MetadataChangeProposalWrapper]]:
+        """Process all queries in memory (original behavior)."""
         with self.report.new_stage("Collecting queries from file"):
             queries = list(self._parse_query_file())
             logger.info(f"Collected {len(queries)} queries for processing")

         with self.report.new_stage("Processing queries through SQL parsing aggregator"):
-            for query_entry in queries:
-                self._add_query_to_aggregator(query_entry)
+            logger.info("Using sequential processing")
+            self._process_queries_sequential(queries)

         with self.report.new_stage("Generating metadata work units"):
             logger.info("Generating workunits from SQL parsing aggregator")
-            yield from self.aggregator.gen_metadata()
+            yield from auto_workunit(self.aggregator.gen_metadata())

-    def _parse_query_file(self) -> Iterable["QueryEntry"]:
-        """Parse the query file and yield QueryEntry objects."""
+    def _is_s3_uri(self, path: str) -> bool:
+        """Check if the path is an S3 URI."""
+        return path.startswith("s3://")
+
+    def _parse_s3_query_file(self) -> Iterable["QueryEntry"]:
+        """Parse query file from S3 using smart_open."""
+        if not self.config.aws_config:
+            raise ValueError("AWS configuration required for S3 file access")
+
+        logger.info(f"Reading query file from S3: {self.config.query_file}")
+
+        try:
+            # Use smart_open for efficient S3 streaming, similar to S3FileSystem
+            s3_client = self.config.aws_config.get_s3_client()
+
+            with smart_open.open(
+                self.config.query_file, mode="r", transport_params={"client": s3_client}
+            ) as file_stream:
+                for line in file_stream:
+                    if line.strip():
+                        try:
+                            query_dict = json.loads(line, strict=False)
+                            entry = QueryEntry.create(query_dict, config=self.config)
+                            self.report.num_entries_processed += 1
+                            if self.report.num_entries_processed % 1000 == 0:
+                                logger.info(
+                                    f"Processed {self.report.num_entries_processed} query entries from S3"
+                                )
+                            yield entry
+                        except Exception as e:
+                            self.report.num_entries_failed += 1
+                            self.report.warning(
+                                title="Error processing query from S3",
+                                message="Query skipped due to parsing error",
+                                context=line.strip(),
+                                exc=e,
+                            )
+        except Exception as e:
+            self.report.warning(
+                title="Error reading S3 file",
+                message="Failed to read S3 file",
+                context=self.config.query_file,
+                exc=e,
+            )
+            raise
+
+    def _parse_local_query_file(self) -> Iterable["QueryEntry"]:
+        """Parse local query file (existing logic)."""
         with open(self.config.query_file) as f:
             for line in f:
                 try:
@@ -227,6 +334,30 @@ def _parse_query_file(self) -> Iterable["QueryEntry"]:
                         exc=e,
                     )

+    def _parse_query_file(self) -> Iterable["QueryEntry"]:
+        """Parse the query file and yield QueryEntry objects."""
+        if self._is_s3_uri(self.config.query_file):
+            yield from self._parse_s3_query_file()
+        else:
+            yield from self._parse_local_query_file()
+
+    def _process_queries_sequential(self, queries: List["QueryEntry"]) -> None:
+        """Process queries sequentially."""
+        total_queries = len(queries)
+        logger.info(f"Processing {total_queries} queries sequentially")
+
+        # Process each query sequentially
+        for i, query_entry in enumerate(queries):
+            self._add_query_to_aggregator(query_entry)
+            self.report.num_queries_processed_sequential += 1
+
+            # Simple progress reporting every 1000 queries
+            if (i + 1) % 1000 == 0:
+                progress_pct = ((i + 1) / total_queries) * 100
+                logger.info(
+                    f"Processed {i + 1}/{total_queries} queries ({progress_pct:.1f}%)"
+                )
+
     def _add_query_to_aggregator(self, query_entry: "QueryEntry") -> None:
         """Add a query to the SQL parsing aggregator."""
         try:
@@ -285,6 +416,24 @@ def _add_query_to_aggregator(self, query_entry: "QueryEntry") -> None:
                 exc=e,
             )

+    def is_temp_table(self, name: str) -> bool:
+        """Check if a table name matches any of the configured temp table patterns."""
+        if not self.config.temp_table_patterns:
+            return False
+
+        try:
+            for pattern in self.config.temp_table_patterns:
+                if re.match(pattern, name, flags=re.IGNORECASE):
+                    logger.debug(
+                        f"Table '{name}' matched temp table pattern: {pattern}"
+                    )
+                    self.report.num_temp_tables_detected += 1
+                    return True
+        except re.error as e:
+            logger.warning(f"Invalid regex pattern '{pattern}': {e}")
+
+        return False
+

 class QueryEntry(BaseModel):
     query: str