Commit 1e5ff71

iceberg_loader: Add _table_cache - avoid fetching catalog for each batch
- Also use _create_table_from_schema() to match base DataLoader class
1 parent 6d6a64f commit 1e5ff71
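
The commit message describes the core change: table handles are cached by identifier so the catalog is contacted once per table instead of once per batch. Below is a minimal standalone sketch of that pattern (not the loader itself), assuming a configured pyiceberg catalog named 'default' and an existing table 'analytics.events'; both names are made up for illustration.

from typing import Dict

from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.table import Table

# Module-level cache standing in for the loader's _table_cache attribute
_table_cache: Dict[str, Table] = {}

def get_table(catalog: Catalog, identifier: str) -> Table:
    """Return a cached Table handle; only the first call per identifier hits the catalog."""
    if identifier not in _table_cache:
        _table_cache[identifier] = catalog.load_table(identifier)  # one catalog round-trip
    return _table_cache[identifier]

catalog = load_catalog('default')
for _ in range(3):
    table = get_table(catalog, 'analytics.events')  # loaded once, then reused from the cache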

File tree

1 file changed (+56, -31 lines)


src/amp/loaders/implementations/iceberg_loader.py

Lines changed: 56 additions & 31 deletions
@@ -86,6 +86,7 @@ def __init__(self, config: Dict[str, Any]):
         self._current_table: Optional[IcebergTable] = None
         self._namespace_exists: bool = False
         self.enable_statistics: bool = config.get('enable_statistics', True)
+        self._table_cache: Dict[str, IcebergTable] = {}  # Cache tables by identifier
 
     def _get_required_config_fields(self) -> list[str]:
         """Return required configuration fields"""
@@ -117,6 +118,7 @@ def disconnect(self) -> None:
         if self._catalog:
             self._catalog = None
 
+        self._table_cache.clear()  # Clear table cache on disconnect
         self._is_connected = False
         self.logger.info('Iceberg loader disconnected')
 
@@ -128,9 +130,16 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->
         # Fix timestamps for Iceberg compatibility
         table = self._fix_timestamps(table)
 
-        # Get or create the Iceberg table
+        # Get the Iceberg table (already created by _create_table_from_schema if needed)
         mode = kwargs.get('mode', LoadMode.APPEND)
-        iceberg_table = self._get_or_create_table(table_name, table.schema)
+        table_identifier = f'{self.config.namespace}.{table_name}'
+
+        # Use cached table if available
+        if table_identifier in self._table_cache:
+            iceberg_table = self._table_cache[table_identifier]
+        else:
+            iceberg_table = self._catalog.load_table(table_identifier)
+            self._table_cache[table_identifier] = iceberg_table
 
         # Validate schema compatibility (unless overwriting)
         if mode != LoadMode.OVERWRITE:
@@ -141,15 +150,28 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->
 
         return rows_written
 
-    def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
-        """Create table from Arrow schema"""
-        # Iceberg handles table creation in _get_or_create_table
-        self.logger.info(f"Iceberg will create table '{table_name}' on first write with appropriate schema")
-
     def _clear_table(self, table_name: str) -> None:
         """Clear table for overwrite mode"""
         # Iceberg handles overwrites internally
-        self.logger.info(f"Iceberg will handle overwrite for table '{table_name}'")
+        # Clear from cache to ensure fresh state after overwrite
+        table_identifier = f'{self.config.namespace}.{table_name}'
+        if table_identifier in self._table_cache:
+            del self._table_cache[table_identifier]
+
+    def _fix_schema_timestamps(self, schema: pa.Schema) -> pa.Schema:
+        """Convert nanosecond timestamps to microseconds in schema for Iceberg compatibility"""
+        # Check if conversion is needed
+        if not any(pa.types.is_timestamp(f.type) and f.type.unit == 'ns' for f in schema):
+            return schema
+
+        new_fields = []
+        for field in schema:
+            if pa.types.is_timestamp(field.type) and field.type.unit == 'ns':
+                new_fields.append(pa.field(field.name, pa.timestamp('us', tz=field.type.tz)))
+            else:
+                new_fields.append(field)
+
+        return pa.schema(new_fields)
 
     def _fix_timestamps(self, arrow_table: pa.Table) -> pa.Table:
         """Convert nanosecond timestamps to microseconds for Iceberg compatibility"""
@@ -217,33 +239,36 @@ def _check_namespace_exists(self, namespace: str) -> None:
         except Exception as e:
             raise NoSuchNamespaceError(f"Failed to verify namespace '{namespace}': {str(e)}") from e
 
-    def _get_or_create_table(self, table_name: str, schema: pa.Schema) -> IcebergTable:
-        """Get existing table or create new one"""
+    def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
+        """Create table if it doesn't exist - called once by base class before first batch"""
+        if not self.config.create_table:
+            # If create_table is False, just verify table exists
+            table_identifier = f'{self.config.namespace}.{table_name}'
+            try:
+                table = self._catalog.load_table(table_identifier)
+                # Cache the existing table
+                self._table_cache[table_identifier] = table
+                self.logger.debug(f'Table already exists: {table_identifier}')
+            except (NoSuchTableError, NoSuchIcebergTableError) as e:
+                raise NoSuchTableError(f"Table '{table_identifier}' not found and create_table=False") from e
+            return
+
         table_identifier = f'{self.config.namespace}.{table_name}'
 
-        try:
-            table = self._catalog.load_table(table_identifier)
-            self.logger.debug(f'Loaded existing table: {table_identifier}')
-            return table
+        # Fix timestamps in schema before creating table
+        fixed_schema = self._fix_schema_timestamps(schema)
 
-        except (NoSuchTableError, NoSuchIcebergTableError) as e:
-            if not self.config.create_table:
-                raise NoSuchTableError(f"Table '{table_identifier}' not found and create_table=False") from e
-
-            try:
-                # Use partition_spec if provided
-                if self.config.partition_spec:
-                    table = self._catalog.create_table(
-                        identifier=table_identifier, schema=schema, partition_spec=self.config.partition_spec
-                    )
-                else:
-                    # Create table without partitioning
-                    table = self._catalog.create_table(identifier=table_identifier, schema=schema)
-                self.logger.info(f'Created new table: {table_identifier}')
-                return table
+        # Use create_table_if_not_exists for simpler logic
+        if self.config.partition_spec:
+            table = self._catalog.create_table_if_not_exists(
+                identifier=table_identifier, schema=fixed_schema, partition_spec=self.config.partition_spec
+            )
+        else:
+            table = self._catalog.create_table_if_not_exists(identifier=table_identifier, schema=fixed_schema)
 
-            except Exception as e:
-                raise RuntimeError(f"Failed to create table '{table_identifier}': {str(e)}") from e
+        # Cache the newly created/loaded table
+        self._table_cache[table_identifier] = table
+        self.logger.info(f"Table '{table_identifier}' ready (created if needed)")
 
     def _validate_schema_compatibility(self, iceberg_table: IcebergTable, arrow_schema: pa.Schema) -> None:
         """Validate that Arrow schema is compatible with Iceberg table schema and perform schema evolution if enabled"""
