
Commit 7f7d0df

parallel streaming: Create table before starting parallel workers
1 parent 66f6c1b

2 files changed (+82 additions, -4 deletions)

src/amp/loaders/implementations/snowflake_loader.py (5 additions, 1 deletion)

@@ -163,8 +163,12 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->
                 'Please use APPEND mode or manually truncate/drop the table before loading.'
             )
 
+        # Table creation is now handled by base class or pre-flight creation in parallel mode
+        # For pandas loading, we skip manual table creation and let write_pandas handle it
         if create_table and table_name.upper() not in self._created_tables:
-            self._create_table_from_schema(batch.schema, table_name)
+            # For pandas, skip table creation - write_pandas will handle it
+            if self.loading_method != 'pandas':
+                self._create_table_from_schema(batch.schema, table_name)
             self._created_tables.add(table_name.upper())
 
         if self.use_stage:
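
The pandas path can safely skip manual DDL because the Snowflake connector's write_pandas is able to create the destination table itself, presumably via its auto_create_table option. A minimal sketch of that behavior, assuming a configured snowflake-connector-python connection (the credentials, table name, and DataFrame below are illustrative, not taken from this repo):

import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

# Illustrative connection parameters; replace with real credentials.
conn = snowflake.connector.connect(
    account='MY_ACCOUNT', user='MY_USER', password='...',
    warehouse='MY_WH', database='MY_DB', schema='PUBLIC',
)

df = pd.DataFrame({'BLOCK_NUMBER': [1, 2], 'TX_COUNT': [10, 12]})

# With auto_create_table=True, write_pandas infers column types from the
# DataFrame and issues the CREATE TABLE itself, which is why the loader
# no longer needs to call _create_table_from_schema on this path.
success, nchunks, nrows, _ = write_pandas(
    conn, df, table_name='EVENTS', auto_create_table=True
)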

src/amp/streaming/parallel.py (77 additions, 3 deletions)

@@ -366,15 +366,89 @@ def execute_parallel_stream(
             f'Starting parallel streaming with {len(partitions)} partitions across {self.config.num_workers} workers'
         )
 
-        # 2. Submit worker tasks
+        # 2. Pre-flight table creation (before workers start)
+        # Create table once to avoid locking complexity in parallel workers
+        try:
+            # Get connection info
+            connection_info = self.client.connection_manager.get_connection_info(connection_name)
+            loader_config = connection_info['config']
+            loader_type = connection_info['loader']
+
+            # Get sample schema by executing LIMIT 1 on original query
+            # We don't need partition filtering for schema detection, just need any row
+            sample_query = user_query.strip().rstrip(';')
+
+            # Remove SETTINGS clause (especially stream = true) to avoid streaming mode
+            sample_query_upper = sample_query.upper()
+            settings_pos = sample_query_upper.find(' SETTINGS ')
+            if settings_pos != -1:
+                sample_query = sample_query[:settings_pos].rstrip()
+                sample_query_upper = sample_query.upper()
+
+            # Insert LIMIT 1 before ORDER BY, GROUP BY if present
+            end_keywords = [' ORDER BY ', ' GROUP BY ']
+            insert_pos = len(sample_query)
+
+            for keyword in end_keywords:
+                keyword_pos = sample_query_upper.find(keyword)
+                if keyword_pos != -1 and keyword_pos < insert_pos:
+                    insert_pos = keyword_pos
+
+            # Insert LIMIT 1 at the correct position
+            sample_query = sample_query[:insert_pos].rstrip() + ' LIMIT 1' + sample_query[insert_pos:]
+
+            self.logger.debug(f"Fetching schema with sample query: {sample_query[:100]}...")
+            sample_table = self.client.get_sql(sample_query, read_all=True)
+
+            if sample_table.num_rows > 0:
+                # Create loader instance to get effective schema and create table
+                from ..loaders.registry import create_loader
+
+                loader_instance = create_loader(loader_type, loader_config, label_manager=self.client.label_manager)
+
+                try:
+                    loader_instance.connect()
+
+                    # Get effective schema (includes labels if configured)
+                    sample_batch = sample_table.to_batches()[0]
+                    effective_schema = loader_instance._get_effective_schema(
+                        sample_batch.schema,
+                        load_config.get('label'),
+                        load_config.get('label_key_column'),
+                        load_config.get('stream_key_column')
+                    )
+
+                    # Create table once with effective schema
+                    if hasattr(loader_instance, '_create_table_from_schema'):
+                        loader_instance._create_table_from_schema(effective_schema, destination)
+                        loader_instance._created_tables.add(destination)
+                        self.logger.info(
+                            f"Pre-created table '{destination}' with {len(effective_schema)} columns "
+                            f"(includes label columns if configured)"
+                        )
+                    else:
+                        self.logger.warning(f"Loader does not support table creation, workers will handle it")
+                finally:
+                    loader_instance.disconnect()
+            else:
+                self.logger.warning("Sample query returned no rows, skipping pre-flight table creation")
+
+            # Update load_config to skip table creation in workers
+            load_config['create_table'] = False
+
+        except Exception as e:
+            self.logger.warning(f"Pre-flight table creation failed: {e}. Workers will attempt table creation with locking.")
+            # Don't fail the entire job - let workers try to create the table
+
+        # 3. Submit worker tasks
         futures = {}
         for partition in partitions:
             future = self.executor.submit(
                 self._execute_partition, user_query, partition, destination, connection_name, load_config
             )
             futures[future] = partition
 
-        # 3. Stream results as they complete
+        # 4. Stream results as they complete
         try:
             for future in as_completed(futures):
                 partition = futures[future]

@@ -406,7 +480,7 @@ def execute_parallel_stream(
         self.executor.shutdown(wait=True)
         self._log_final_stats()
 
-        # 4. If in hybrid mode, transition to continuous streaming for live blocks
+        # 5. If in hybrid mode, transition to continuous streaming for live blocks
         if continue_streaming:
             # Start continuous streaming with a buffer for reorg overlap
             # This ensures we catch any reorgs that occurred during parallel catchup
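
The schema-probe rewrite in the first hunk is plain string manipulation and is easy to see in isolation. Below is a self-contained sketch of the same logic; the function name build_sample_query is hypothetical and only mirrors what the hunk inlines:

def build_sample_query(user_query: str) -> str:
    """Rewrite a query into a cheap LIMIT 1 probe for schema detection."""
    sample_query = user_query.strip().rstrip(';')

    # Strip any SETTINGS clause (e.g. 'SETTINGS stream = true') so the
    # probe runs as a one-shot query rather than a stream.
    upper = sample_query.upper()
    settings_pos = upper.find(' SETTINGS ')
    if settings_pos != -1:
        sample_query = sample_query[:settings_pos].rstrip()
        upper = sample_query.upper()

    # Find the earliest trailing clause and insert LIMIT 1 in front of it,
    # exactly as the hunk does; with no such clause, LIMIT 1 is appended.
    insert_pos = len(sample_query)
    for keyword in (' ORDER BY ', ' GROUP BY '):
        pos = upper.find(keyword)
        if pos != -1 and pos < insert_pos:
            insert_pos = pos
    return sample_query[:insert_pos].rstrip() + ' LIMIT 1' + sample_query[insert_pos:]

# Example probes:
print(build_sample_query('SELECT * FROM blocks SETTINGS stream = true;'))
# SELECT * FROM blocks LIMIT 1
print(build_sample_query('SELECT number, hash FROM blocks ORDER BY number'))
# SELECT number, hash FROM blocks LIMIT 1 ORDER BY number

The single row returned by the probe supplies an Arrow schema, from which the loader derives the effective schema (label columns included) and creates the table exactly once; setting load_config['create_table'] = False then tells every worker to skip DDL, removing the need for cross-worker locking.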
