Commit 5b5d09a

Add insert API improvements: validate(), chunk_size, insert_dataframe()
Non-breaking improvements to the insert API:

1. validate(rows) method
   - Validates rows without inserting
   - Returns ValidationResult with is_valid, errors, rows_checked
   - Checks field existence, row format, codec validation, NULL constraints
   - Supports ignore_extra_fields parameter

2. chunk_size parameter for insert()
   - Enables memory-efficient batch inserts for large datasets
   - Each chunk is a separate transaction

3. insert_dataframe(df, index_as_pk=None) method
   - Explicit DataFrame index handling
   - Auto-detects when index matches primary key (symmetric with to_pandas())
   - Supports index_as_pk=True/False for explicit control

4. Deprecation warning for positional inserts
   - Warns when using tuples/lists instead of dicts
   - Encourages explicit field names for clarity

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 6b56f25 commit 5b5d09a
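
The four changes compose into a single pre-validated, chunked loading path. A minimal usage sketch (the `Session` table, its fields, and the sample rows are hypothetical; the methods are the ones this commit adds):

    rows = [
        {"session_id": 1, "subject": "mouse01", "duration": 3600.0},
        {"session_id": 2, "subject": "mouse02", "duration": 1800.0},
    ]

    result = Session().validate(rows)              # (1) check without writing
    if not result:
        print(result.summary())                    # per-row, per-field errors
    else:
        Session().insert(rows, chunk_size=10_000)  # (2) batched transactions

    df = Session().to_pandas()                     # primary key lands in the index
    Session().insert_dataframe(df, skip_duplicates=True)  # (3) index-aware re-insert

    Session().insert1((3, "mouse03", 900.0))       # (4) now emits DeprecationWarning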

4 files changed: +620, −2 lines

4 files changed

+620
-2
lines changed

src/datajoint/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -57,6 +57,7 @@
     "logger",
     "cli",
     "ObjectRef",
+    "ValidationResult",
 ]

 from . import errors
@@ -78,7 +79,7 @@
 from .objectref import ObjectRef
 from .schemas import Schema, VirtualModule, list_schemas
 from .settings import config
-from .table import FreeTable, Table
+from .table import FreeTable, Table, ValidationResult
 from .user_tables import Computed, Imported, Lookup, Manual, Part
 from .version import __version__
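
Re-exporting the dataclass at the package root keeps call sites free of deep imports; a quick sanity check, assuming the package installs cleanly:

    import datajoint as dj

    res = dj.ValidationResult(is_valid=True, rows_checked=3)
    assert bool(res)  # __bool__ delegates to is_valid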

src/datajoint/table.py

Lines changed: 288 additions & 1 deletion
@@ -6,6 +6,8 @@
 import logging
 import re
 import uuid
+import warnings
+from dataclasses import dataclass, field
 from pathlib import Path

 import numpy as np
@@ -54,6 +56,43 @@ class _RenameMap(tuple):
     pass


+@dataclass
+class ValidationResult:
+    """
+    Result of table.validate() call.
+
+    Attributes:
+        is_valid: True if all rows passed validation
+        errors: List of (row_index, field_name, error_message) tuples
+        rows_checked: Number of rows that were validated
+    """
+
+    is_valid: bool
+    errors: list = field(default_factory=list)  # list of (row_index, field_name | None, message)
+    rows_checked: int = 0
+
+    def __bool__(self) -> bool:
+        """Allow using ValidationResult in boolean context."""
+        return self.is_valid
+
+    def raise_if_invalid(self):
+        """Raise DataJointError if validation failed."""
+        if not self.is_valid:
+            raise DataJointError(self.summary())
+
+    def summary(self) -> str:
+        """Return formatted error summary."""
+        if self.is_valid:
+            return f"Validation passed: {self.rows_checked} rows checked"
+        lines = [f"Validation failed: {len(self.errors)} error(s) in {self.rows_checked} rows"]
+        for row_idx, field_name, message in self.errors[:10]:  # Show first 10 errors
+            field_str = f" in field '{field_name}'" if field_name else ""
+            lines.append(f"  Row {row_idx}{field_str}: {message}")
+        if len(self.errors) > 10:
+            lines.append(f"  ... and {len(self.errors) - 10} more errors")
+        return "\n".join(lines)
+
+
 class Table(QueryExpression):
     """
     Table is an abstract class that represents a table in the schema.
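
Because `ValidationResult` is a plain dataclass, its three behaviors (truthiness, `summary()`, `raise_if_invalid()`) can be exercised without a database connection. A small sketch with hand-built errors, following the `(row_index, field_name | None, message)` shape documented above:

    result = ValidationResult(
        is_valid=False,
        errors=[
            (0, "subject", "Required field 'subject' is missing"),
            (3, None, "Invalid row type: str"),
        ],
        rows_checked=4,
    )

    if not result:               # __bool__ delegates to is_valid
        print(result.summary())  # header line plus up to 10 per-row details
    # result.raise_if_invalid()  # would raise DataJointError(result.summary())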
@@ -375,6 +414,143 @@ def update1(self, row):
         )
         self.connection.query(query, args=list(r[2] for r in row if r[2] is not None))

+    def validate(self, rows, *, ignore_extra_fields=False) -> ValidationResult:
+        """
+        Validate rows without inserting them.
+
+        :param rows: Same format as insert() - iterable of dicts, tuples, numpy records,
+            or a pandas DataFrame.
+        :param ignore_extra_fields: If True, ignore fields not in the table heading.
+        :return: ValidationResult with is_valid, errors list, and rows_checked count.
+
+        Validates:
+        - Field existence (all fields must be in table heading)
+        - Row format (correct number of attributes for positional inserts)
+        - Codec validation (type checking via codec.validate())
+        - NULL constraints (non-nullable fields must have values)
+        - Primary key completeness (all PK fields must be present)
+        - UUID format and JSON serializability
+
+        Cannot validate (database-enforced):
+        - Foreign key constraints
+        - Unique constraints (other than PK)
+        - Custom MySQL constraints
+
+        Example::
+
+            result = table.validate(rows)
+            if result:
+                table.insert(rows)
+            else:
+                print(result.summary())
+        """
+        errors = []
+
+        # Convert DataFrame to records
+        if isinstance(rows, pandas.DataFrame):
+            rows = rows.reset_index(drop=len(rows.index.names) == 1 and not rows.index.names[0]).to_records(index=False)
+
+        # Convert Path (CSV) to list of dicts
+        if isinstance(rows, Path):
+            with open(rows, newline="") as data_file:
+                rows = list(csv.DictReader(data_file, delimiter=","))
+
+        rows = list(rows)  # Materialize iterator
+        row_count = len(rows)
+
+        for row_idx, row in enumerate(rows):
+            # Validate row format and fields
+            row_dict = None
+            try:
+                if isinstance(row, np.void):  # numpy record
+                    fields = list(row.dtype.fields.keys())
+                    row_dict = {name: row[name] for name in fields}
+                elif isinstance(row, collections.abc.Mapping):
+                    fields = list(row.keys())
+                    row_dict = dict(row)
+                else:  # positional tuple/list
+                    if len(row) != len(self.heading):
+                        errors.append(
+                            (
+                                row_idx,
+                                None,
+                                f"Incorrect number of attributes: {len(row)} given, {len(self.heading)} expected",
+                            )
+                        )
+                        continue
+                    fields = list(self.heading.names)
+                    row_dict = dict(zip(fields, row))
+            except TypeError:
+                errors.append((row_idx, None, f"Invalid row type: {type(row).__name__}"))
+                continue
+
+            # Check for unknown fields
+            if not ignore_extra_fields:
+                for field_name in fields:
+                    if field_name not in self.heading:
+                        errors.append((row_idx, field_name, f"Field '{field_name}' not in table heading"))
+
+            # Validate each field value
+            for name in self.heading.names:
+                if name not in row_dict:
+                    # Check if field is required (non-nullable, no default, not autoincrement)
+                    attr = self.heading[name]
+                    if not attr.nullable and attr.default is None and not attr.autoincrement:
+                        errors.append((row_idx, name, f"Required field '{name}' is missing"))
+                    continue
+
+                value = row_dict[name]
+                attr = self.heading[name]
+
+                # Skip validation for None values on nullable columns
+                if value is None:
+                    if not attr.nullable and attr.default is None:
+                        errors.append((row_idx, name, f"NULL value not allowed for non-nullable field '{name}'"))
+                    continue
+
+                # Codec validation
+                if attr.codec:
+                    try:
+                        attr.codec.validate(value)
+                    except (TypeError, ValueError) as e:
+                        errors.append((row_idx, name, f"Codec validation failed: {e}"))
+                    continue
+
+                # UUID validation
+                if attr.uuid and not isinstance(value, uuid.UUID):
+                    try:
+                        uuid.UUID(value)
+                    except (AttributeError, ValueError):
+                        errors.append((row_idx, name, f"Invalid UUID format: {value}"))
+                    continue
+
+                # JSON serialization check
+                if attr.json:
+                    try:
+                        json.dumps(value)
+                    except (TypeError, ValueError) as e:
+                        errors.append((row_idx, name, f"Value not JSON serializable: {e}"))
+                    continue
+
+                # Numeric NaN check
+                if attr.numeric and value != "" and not isinstance(value, bool):
+                    try:
+                        if np.isnan(float(value)):
+                            # NaN is allowed - will be converted to NULL
+                            pass
+                    except (TypeError, ValueError):
+                        # Not a number that can be checked for NaN - let it pass
+                        pass
+
+            # Check primary key completeness
+            for pk_field in self.primary_key:
+                if pk_field not in row_dict or row_dict[pk_field] is None:
+                    pk_attr = self.heading[pk_field]
+                    if not pk_attr.autoincrement:
+                        errors.append((row_idx, pk_field, f"Primary key field '{pk_field}' is missing or NULL"))
+
+        return ValidationResult(is_valid=len(errors) == 0, errors=errors, rows_checked=row_count)
+
     def insert1(self, row, **kwargs):
         """
         Insert one data record into the table. For ``kwargs``, see ``insert()``.
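
A typical gating pattern built on this method (the `Trial` table and `log` handle are hypothetical):

    trial = Trial()
    result = trial.validate(rows, ignore_extra_fields=True)
    if result:
        trial.insert(rows, ignore_extra_fields=True)
    else:
        for row_idx, field_name, message in result.errors:
            log.warning("row %d, field %s: %s", row_idx, field_name, message)

Note that validate() materializes its input with list(rows), so pass a list rather than a one-shot generator when the same rows are to be inserted afterward.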
@@ -420,6 +596,7 @@ def insert(
         skip_duplicates=False,
         ignore_extra_fields=False,
         allow_direct_insert=None,
+        chunk_size=None,
     ):
         """
         Insert a collection of rows.
@@ -434,12 +611,17 @@
         :param ignore_extra_fields: If False, fields that are not in the heading raise error.
         :param allow_direct_insert: Only applies in auto-populated tables. If False (default),
             insert may only be called from inside the make callback.
+        :param chunk_size: If set, insert rows in batches of this size. Useful for very
+            large inserts to avoid memory issues. Each chunk is a separate transaction.

         Example:

         >>> Table.insert([
         >>>     dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"),
         >>>     dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")])
+
+        # Large insert with chunking
+        >>> Table.insert(large_dataset, chunk_size=10000)
         """
         if isinstance(rows, pandas.DataFrame):
             # drop 'extra' synthetic index for 1-field index case -
@@ -461,7 +643,9 @@
         if inspect.isclass(rows) and issubclass(rows, QueryExpression):
             rows = rows()  # instantiate if a class
         if isinstance(rows, QueryExpression):
-            # insert from select
+            # insert from select - chunk_size not applicable
+            if chunk_size is not None:
+                raise DataJointError("chunk_size is not supported for QueryExpression inserts")
             if not ignore_extra_fields:
                 try:
                     raise DataJointError(
@@ -485,6 +669,28 @@
             self.connection.query(query)
             return

+        # Chunked insert mode
+        if chunk_size is not None:
+            rows_iter = iter(rows)
+            while True:
+                chunk = list(itertools.islice(rows_iter, chunk_size))
+                if not chunk:
+                    break
+                self._insert_rows(chunk, replace, skip_duplicates, ignore_extra_fields)
+            return
+
+        # Single batch insert (original behavior)
+        self._insert_rows(rows, replace, skip_duplicates, ignore_extra_fields)
+
+    def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields):
+        """
+        Internal helper to insert a batch of rows.
+
+        :param rows: Iterable of rows to insert
+        :param replace: If True, use REPLACE instead of INSERT
+        :param skip_duplicates: If True, use ON DUPLICATE KEY UPDATE
+        :param ignore_extra_fields: If True, ignore unknown fields
+        """
         # collects the field list from first row (passed by reference)
         field_list = []
         rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows)
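
Since itertools.islice() pulls from the iterator lazily, the full dataset is never materialized, and because each chunk commits separately, a failure in chunk N leaves chunks 1..N-1 inserted. A streaming sketch (the file format, reader, and `Session` table are hypothetical):

    def stream_rows(path):
        # yields one dict per CSV line of the form: session_id,subject
        with open(path) as f:
            for line in f:
                session_id, subject = line.rstrip("\n").split(",")
                yield {"session_id": int(session_id), "subject": subject}

    # 5_000-row chunks, each committed in its own transaction
    Session().insert(stream_rows("sessions.csv"), chunk_size=5_000)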
@@ -508,6 +714,81 @@
         except DuplicateError as err:
             raise err.suggest("To ignore duplicate entries in insert, set skip_duplicates=True")

+    def insert_dataframe(self, df, index_as_pk=None, **insert_kwargs):
+        """
+        Insert DataFrame with explicit index handling.
+
+        This method provides symmetry with to_pandas(): data fetched with to_pandas()
+        (which sets primary key as index) can be modified and re-inserted using
+        insert_dataframe() without manual index manipulation.
+
+        :param df: pandas DataFrame to insert
+        :param index_as_pk: How to handle DataFrame index:
+            - None (default): Auto-detect. Use index as primary key if index names
+              match primary_key columns. Drop if unnamed RangeIndex.
+            - True: Treat index as primary key columns. Raises if index names don't
+              match table primary key.
+            - False: Ignore index entirely (drop it).
+        :param **insert_kwargs: Passed to insert() - replace, skip_duplicates,
+            ignore_extra_fields, allow_direct_insert, chunk_size
+
+        Example::
+
+            # Round-trip with to_pandas()
+            df = table.to_pandas()  # PK becomes index
+            df['value'] = df['value'] * 2  # Modify data
+            table.insert_dataframe(df)  # Auto-detects index as PK
+
+            # Explicit control
+            table.insert_dataframe(df, index_as_pk=True)   # Use index
+            table.insert_dataframe(df, index_as_pk=False)  # Ignore index
+        """
+        if not isinstance(df, pandas.DataFrame):
+            raise DataJointError("insert_dataframe requires a pandas DataFrame")
+
+        # Auto-detect if index should be used as PK
+        if index_as_pk is None:
+            index_as_pk = self._should_index_be_pk(df)
+
+        # Validate index if using as PK
+        if index_as_pk:
+            self._validate_index_columns(df)
+
+        # Prepare rows
+        if index_as_pk:
+            rows = df.reset_index(drop=False).to_records(index=False)
+        else:
+            rows = df.reset_index(drop=True).to_records(index=False)
+
+        self.insert(rows, **insert_kwargs)
+
+    def _should_index_be_pk(self, df) -> bool:
+        """
+        Auto-detect if DataFrame index should map to primary key.
+
+        Returns True if:
+        - Index has named columns that exactly match the table's primary key
+        Returns False if:
+        - Index is unnamed RangeIndex (synthetic index)
+        - Index names don't match primary key
+        """
+        # RangeIndex with no name -> False (synthetic index)
+        if df.index.names == [None]:
+            return False
+        # Check if index names match PK columns
+        index_names = set(n for n in df.index.names if n is not None)
+        return index_names == set(self.primary_key)
+
+    def _validate_index_columns(self, df):
+        """Validate that index columns match the table's primary key."""
+        index_names = [n for n in df.index.names if n is not None]
+        if set(index_names) != set(self.primary_key):
+            raise DataJointError(
+                f"DataFrame index columns {index_names} do not match "
+                f"table primary key {list(self.primary_key)}. "
+                f"Use index_as_pk=False to ignore index, or reset_index() first."
+            )
+
     def delete_quick(self, get_count=False):
         """
         Deletes the table without cascading and without user prompt.
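
A round-trip sketch of the auto-detection (hypothetical `Session` table with composite primary key (subject_id, session_id); since _should_index_be_pk() compares index names as a set, a MultiIndex produced by to_pandas() is recognized):

    import pandas as pd

    df = Session().to_pandas()                    # composite PK -> MultiIndex
    df["duration"] = df["duration"] * 2           # edit in place
    Session().insert_dataframe(df, replace=True)  # index names match PK -> used as key

    # a freshly built frame with an unnamed RangeIndex is simply dropped:
    new = pd.DataFrame({"subject_id": [1], "session_id": [9], "duration": [42.0]})
    Session().insert_dataframe(new)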
@@ -921,6 +1202,12 @@ def check_fields(fields):
                 if name in row
             ]
         else:  # positional
+            warnings.warn(
+                "Positional inserts (tuples/lists) are deprecated and will be removed in a future version. "
+                "Use dict with explicit field names instead: table.insert1({'field': value, ...})",
+                DeprecationWarning,
+                stacklevel=4,  # Point to user's insert()/insert1() call
+            )
             try:
                 if len(row) != len(self.heading):
                     raise DataJointError(
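
What callers see once this lands (hypothetical `Session` table; both calls insert the same row):

    import warnings

    Session().insert1((1, "mouse01", 3600.0))  # warns: positional inserts deprecated

    # equivalent explicit form, no warning:
    Session().insert1({"session_id": 1, "subject": "mouse01", "duration": 3600.0})

    # escalate to an error during a migration sweep:
    warnings.filterwarnings("error", category=DeprecationWarning)

stacklevel=4 aims the warning at the user's insert()/insert1() call rather than at the internal row-conversion helper.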

src/datajoint/user_tables.py

Lines changed: 3 additions & 0 deletions
@@ -42,11 +42,14 @@
     "children",
     "insert",
     "insert1",
+    "insert_dataframe",
     "update1",
+    "validate",
     "drop",
     "drop_quick",
     "delete",
     "delete_quick",
+    "staged_insert1",
 }
