Skip to content

Commit f89be6f

Browse files
Add job metadata migration utility and optimize progress()
- Add add_job_metadata_columns() migration utility to migrate.py
  - Adds hidden columns to existing Computed/Imported tables
  - Supports single tables or entire schemas
  - Dry-run mode for previewing changes
- Optimize AutoPopulate.progress() with single aggregation query
  - Uses LEFT JOIN with COUNT(DISTINCT) for efficiency
  - Handles 1:many relationships correctly
  - Falls back to two-query method when no common attributes
- Remove target property from AutoPopulate (always uses self)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 574b5f1 commit f89be6f

File tree

3 files changed

+227
-22
lines changed

3 files changed

+227
-22
lines changed

src/datajoint/autopopulate.py

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _rename_attributes(table, props):
127127
)
128128

129129
if self._key_source is None:
130-
parents = self.target.parents(primary=True, as_objects=True, foreign_key_info=True)
130+
parents = self.parents(primary=True, as_objects=True, foreign_key_info=True)
131131
if not parents:
132132
raise DataJointError("A table must have dependencies from its primary key for auto-populate to work")
133133
self._key_source = _rename_attributes(*parents[0])
@@ -204,15 +204,6 @@ def make(self, key):
204204
self.make_insert(key, *computed_result)
205205
yield
206206

207-
@property
208-
def target(self):
209-
"""
210-
:return: table to be populated.
211-
In the typical case, dj.AutoPopulate is mixed into a dj.Table class by
212-
inheritance and the target is self.
213-
"""
214-
return self
215-
216207
def _jobs_to_do(self, restrictions):
217208
"""
218209
:return: the query yielding the keys to be computed (derived from self.key_source)
@@ -235,7 +226,7 @@ def _jobs_to_do(self, restrictions):
235226
raise DataJointError(
236227
"The populate target lacks attribute %s "
237228
"from the primary key of key_source"
238-
% next(name for name in todo.heading.primary_key if name not in self.target.heading)
229+
% next(name for name in todo.heading.primary_key if name not in self.heading)
239230
)
240231
except StopIteration:
241232
pass
@@ -324,7 +315,7 @@ def _populate_direct(
324315
Computes keys directly from key_source, suitable for single-worker
325316
execution, development, and debugging.
326317
"""
327-
keys = (self._jobs_to_do(restrictions) - self.target).fetch("KEY")
318+
keys = (self._jobs_to_do(restrictions) - self).fetch("KEY")
328319

329320
logger.debug("Found %d keys to populate" % len(keys))
330321

@@ -493,14 +484,14 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_
493484
if not is_generator:
494485
self.connection.start_transaction()
495486

496-
if key in self.target: # already populated
487+
if key in self: # already populated
497488
if not is_generator:
498489
self.connection.cancel_transaction()
499490
if jobs is not None:
500491
jobs.complete(key)
501492
return False
502493

503-
logger.debug(f"Making {key} -> {self.target.full_table_name}")
494+
logger.debug(f"Making {key} -> {self.full_table_name}")
504495
self.__class__._allow_insert = True
505496

506497
try:
@@ -531,7 +522,7 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_
531522
exception=error.__class__.__name__,
532523
msg=": " + str(error) if str(error) else "",
533524
)
534-
logger.debug(f"Error making {key} -> {self.target.full_table_name} - {error_message}")
525+
logger.debug(f"Error making {key} -> {self.full_table_name} - {error_message}")
535526
if jobs is not None:
536527
jobs.error(key, error_message=error_message, error_stack=traceback.format_exc())
537528
if not suppress_errors or isinstance(error, SystemExit):
@@ -542,7 +533,7 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_
542533
else:
543534
self.connection.commit_transaction()
544535
duration = time.time() - start_time
545-
logger.debug(f"Success making {key} -> {self.target.full_table_name}")
536+
logger.debug(f"Success making {key} -> {self.full_table_name}")
546537

547538
# Update hidden job metadata if table has the columns
548539
if self._has_job_metadata_attrs():
@@ -564,11 +555,61 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_
564555
def progress(self, *restrictions, display=False):
565556
"""
566557
Report the progress of populating the table.
558+
559+
Uses a single aggregation query to efficiently compute both total and
560+
remaining counts.
561+
562+
:param restrictions: conditions to restrict key_source
563+
:param display: if True, log the progress
567564
:return: (remaining, total) -- numbers of tuples to be populated
568565
"""
569566
todo = self._jobs_to_do(restrictions)
570-
total = len(todo)
571-
remaining = len(todo - self.target)
567+
568+
# Get primary key attributes from key_source for join condition
569+
# These are the "job keys" - the granularity at which populate() works
570+
pk_attrs = todo.primary_key
571+
assert pk_attrs, "key_source must have a primary key"
572+
573+
# Find common attributes between key_source and self for the join
574+
# This handles cases where self has additional PK attributes
575+
common_attrs = [attr for attr in pk_attrs if attr in self.heading.names]
576+
577+
if not common_attrs:
578+
# No common attributes - fall back to two-query method
579+
total = len(todo)
580+
remaining = len(todo - self)
581+
else:
582+
# Build a single query that computes both total and remaining
583+
# Using LEFT JOIN with COUNT(DISTINCT) to handle 1:many relationships
584+
todo_sql = todo.make_sql()
585+
target_sql = self.make_sql()
586+
587+
# Build join condition on common attributes
588+
join_cond = " AND ".join(f"`$ks`.`{attr}` = `$tgt`.`{attr}`" for attr in common_attrs)
589+
590+
# Build DISTINCT key expression for counting unique jobs
591+
# Use CONCAT for composite keys to create a single distinct value
592+
if len(pk_attrs) == 1:
593+
distinct_key = f"`$ks`.`{pk_attrs[0]}`"
594+
null_check = f"`$tgt`.`{common_attrs[0]}`"
595+
else:
596+
distinct_key = "CONCAT_WS('|', {})".format(", ".join(f"`$ks`.`{attr}`" for attr in pk_attrs))
597+
null_check = f"`$tgt`.`{common_attrs[0]}`"
598+
599+
# Single aggregation query:
600+
# - COUNT(DISTINCT key) gives total unique jobs in key_source
601+
# - Remaining = jobs where no matching target row exists
602+
sql = f"""
603+
SELECT
604+
COUNT(DISTINCT {distinct_key}) AS total,
605+
COUNT(DISTINCT CASE WHEN {null_check} IS NULL THEN {distinct_key} END) AS remaining
606+
FROM ({todo_sql}) AS `$ks`
607+
LEFT JOIN ({target_sql}) AS `$tgt` ON {join_cond}
608+
"""
609+
610+
result = self.connection.query(sql).fetchone()
611+
total, remaining = result
612+
572613
if display:
573614
logger.info(
574615
"%-20s" % self.__class__.__name__
@@ -585,7 +626,7 @@ def progress(self, *restrictions, display=False):
585626
def _has_job_metadata_attrs(self):
586627
"""Check if table has hidden job metadata columns."""
587628
# Access _attributes directly to include hidden attributes
588-
all_attrs = self.target.heading._attributes
629+
all_attrs = self.heading._attributes
589630
return all_attrs is not None and "_job_start_time" in all_attrs
590631

591632
def _update_job_metadata(self, key, start_time, duration, version):
@@ -600,9 +641,9 @@ def _update_job_metadata(self, key, start_time, duration, version):
600641
"""
601642
from .condition import make_condition
602643

603-
pk_condition = make_condition(self.target, key, set())
644+
pk_condition = make_condition(self, key, set())
604645
self.connection.query(
605-
f"UPDATE {self.target.full_table_name} SET "
646+
f"UPDATE {self.full_table_name} SET "
606647
"`_job_start_time`=%s, `_job_duration`=%s, `_job_version`=%s "
607648
f"WHERE {pk_condition}",
608649
args=(start_time, duration, version[:64] if version else ""),

src/datajoint/migrate.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,167 @@ def check_migration_status(schema: Schema) -> dict:
248248
"pending": sum(1 for c in columns if c["needs_migration"]),
249249
"columns": columns,
250250
}
251+
252+
253+
# =============================================================================
254+
# Job Metadata Migration
255+
# =============================================================================
256+
257+
# Hidden job metadata columns added by config.jobs.add_job_metadata
258+
JOB_METADATA_COLUMNS = [
259+
("_job_start_time", "datetime(3) DEFAULT NULL"),
260+
("_job_duration", "float DEFAULT NULL"),
261+
("_job_version", "varchar(64) DEFAULT ''"),
262+
]
263+
264+
265+
def _get_existing_columns(connection, database: str, table_name: str) -> set[str]:
266+
"""Get set of existing column names for a table."""
267+
result = connection.query(
268+
"""
269+
SELECT COLUMN_NAME
270+
FROM information_schema.COLUMNS
271+
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
272+
""",
273+
args=(database, table_name),
274+
)
275+
return {row[0] for row in result.fetchall()}
276+
277+
278+
def _is_autopopulated_table(table_name: str) -> bool:
279+
"""Check if a table name indicates a Computed or Imported table."""
280+
# Computed tables start with __ (but not part tables which have __ in middle)
281+
# Imported tables start with _ (but not __)
282+
if table_name.startswith("__"):
283+
# Computed table if no __ after the prefix
284+
return "__" not in table_name[2:]
285+
elif table_name.startswith("_"):
286+
# Imported table
287+
return True
288+
return False
289+
290+
291+
def add_job_metadata_columns(target, dry_run: bool = True) -> dict:
292+
"""
293+
Add hidden job metadata columns to existing Computed/Imported tables.
294+
295+
This migration utility adds the hidden columns (_job_start_time, _job_duration,
296+
_job_version) to tables that were created before config.jobs.add_job_metadata
297+
was enabled.
298+
299+
Args:
300+
target: Either a table class/instance (dj.Computed or dj.Imported) or
301+
a Schema object. If a Schema, all Computed/Imported tables in
302+
the schema will be processed.
303+
dry_run: If True (default), only preview changes without applying.
304+
305+
Returns:
306+
Dict with keys:
307+
- tables_analyzed: Number of tables checked
308+
- tables_modified: Number of tables that were/would be modified
309+
- columns_added: Total columns added across all tables
310+
- details: List of dicts with per-table information
311+
312+
Example:
313+
>>> import datajoint as dj
314+
>>> from datajoint.migrate import add_job_metadata_columns
315+
>>>
316+
>>> # Preview migration for a single table
317+
>>> result = add_job_metadata_columns(MyComputedTable, dry_run=True)
318+
>>> print(f"Would add {result['columns_added']} columns")
319+
>>>
320+
>>> # Apply migration to all tables in a schema
321+
>>> result = add_job_metadata_columns(schema, dry_run=False)
322+
>>> print(f"Modified {result['tables_modified']} tables")
323+
324+
Note:
325+
- Only Computed and Imported tables are modified (not Manual, Lookup, or Part tables)
326+
- Existing rows will have NULL values for _job_start_time and _job_duration
327+
- Future populate() calls will fill in metadata for new rows
328+
- This does NOT retroactively populate metadata for existing rows
329+
"""
330+
from .schemas import Schema
331+
from .table import Table
332+
333+
result = {
334+
"tables_analyzed": 0,
335+
"tables_modified": 0,
336+
"columns_added": 0,
337+
"details": [],
338+
}
339+
340+
# Determine tables to process
341+
if isinstance(target, Schema):
342+
schema = target
343+
# Get all user tables in the schema
344+
tables_query = """
345+
SELECT TABLE_NAME
346+
FROM information_schema.TABLES
347+
WHERE TABLE_SCHEMA = %s
348+
AND TABLE_TYPE = 'BASE TABLE'
349+
AND TABLE_NAME NOT LIKE '~%%'
350+
"""
351+
table_names = [row[0] for row in schema.connection.query(tables_query, args=(schema.database,)).fetchall()]
352+
tables_to_process = [
353+
(schema.database, name, schema.connection) for name in table_names if _is_autopopulated_table(name)
354+
]
355+
elif isinstance(target, type) and issubclass(target, Table):
356+
# Table class
357+
instance = target()
358+
tables_to_process = [(instance.database, instance.table_name, instance.connection)]
359+
elif isinstance(target, Table):
360+
# Table instance
361+
tables_to_process = [(target.database, target.table_name, target.connection)]
362+
else:
363+
raise DataJointError(f"target must be a Table class, Table instance, or Schema, got {type(target)}")
364+
365+
for database, table_name, connection in tables_to_process:
366+
result["tables_analyzed"] += 1
367+
368+
# Skip non-autopopulated tables
369+
if not _is_autopopulated_table(table_name):
370+
continue
371+
372+
# Check which columns need to be added
373+
existing_columns = _get_existing_columns(connection, database, table_name)
374+
columns_to_add = [(name, definition) for name, definition in JOB_METADATA_COLUMNS if name not in existing_columns]
375+
376+
if not columns_to_add:
377+
result["details"].append(
378+
{
379+
"table": f"{database}.{table_name}",
380+
"status": "already_migrated",
381+
"columns_added": 0,
382+
}
383+
)
384+
continue
385+
386+
# Generate and optionally execute ALTER statements
387+
table_detail = {
388+
"table": f"{database}.{table_name}",
389+
"status": "migrated" if not dry_run else "pending",
390+
"columns_added": len(columns_to_add),
391+
"sql_statements": [],
392+
}
393+
394+
for col_name, col_definition in columns_to_add:
395+
sql = f"ALTER TABLE `{database}`.`{table_name}` ADD COLUMN `{col_name}` {col_definition}"
396+
table_detail["sql_statements"].append(sql)
397+
398+
if not dry_run:
399+
try:
400+
connection.query(sql)
401+
logger.info(f"Added column {col_name} to {database}.{table_name}")
402+
except Exception as e:
403+
logger.error(f"Failed to add column {col_name} to {database}.{table_name}: {e}")
404+
table_detail["status"] = "error"
405+
table_detail["error"] = str(e)
406+
raise DataJointError(f"Migration failed: {e}") from e
407+
else:
408+
logger.info(f"Would add column {col_name} to {database}.{table_name}")
409+
410+
result["tables_modified"] += 1
411+
result["columns_added"] += len(columns_to_add)
412+
result["details"].append(table_detail)
413+
414+
return result

src/datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# version bump auto managed by Github Actions:
22
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
33
# manually set this version will be eventually overwritten by the above actions
4-
__version__ = "2.0.0a11"
4+
__version__ = "2.0.0a12"

0 commit comments

Comments (0)