Skip to content

Commit 4164348

Browse files
Add Polars and PyArrow insert support
- insert() now auto-detects polars.DataFrame and pyarrow.Table - Converts via to_dicts()/to_pylist() internally - Soft dependency: no import required, checks type name 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 5b5d09a commit 4164348

File tree

2 files changed

+125
-4
lines changed

2 files changed

+125
-4
lines changed

src/datajoint/table.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -602,10 +602,10 @@ def insert(
602602
Insert a collection of rows.
603603
604604
:param rows: Either (a) an iterable where an element is a numpy record, a
605-
dict-like object, a pandas.DataFrame, a sequence, or a query expression with
606-
the same heading as self, or (b) a pathlib.Path object specifying a path
607-
relative to the current directory with a CSV file, the contents of which
608-
will be inserted.
605+
dict-like object, a pandas.DataFrame, a polars.DataFrame, a pyarrow.Table,
606+
a sequence, or a query expression with the same heading as self, or
607+
(b) a pathlib.Path object specifying a path relative to the current
608+
directory with a CSV file, the contents of which will be inserted.
609609
:param replace: If True, replaces the existing tuple.
610610
:param skip_duplicates: If True, silently skip duplicate inserts.
611611
:param ignore_extra_fields: If False, fields that are not in the heading raise error.
@@ -628,6 +628,14 @@ def insert(
628628
# frames with more advanced indices should be prepared by user.
629629
rows = rows.reset_index(drop=len(rows.index.names) == 1 and not rows.index.names[0]).to_records(index=False)
630630

631+
# Polars DataFrame -> list of dicts (soft dependency, check by type name)
632+
if type(rows).__module__.startswith("polars") and type(rows).__name__ == "DataFrame":
633+
rows = rows.to_dicts()
634+
635+
# PyArrow Table -> list of dicts (soft dependency, check by type name)
636+
if type(rows).__module__.startswith("pyarrow") and type(rows).__name__ == "Table":
637+
rows = rows.to_pylist()
638+
631639
if isinstance(rows, Path):
632640
with open(rows, newline="") as data_file:
633641
rows = list(csv.DictReader(data_file, delimiter=","))

tests/integration/test_insert.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,119 @@ def test_insert_dataframe_with_chunk_size(self, schema_insert):
258258
assert len(table) == 100
259259

260260

261+
try:
262+
import polars
263+
264+
HAS_POLARS = True
265+
except ImportError:
266+
HAS_POLARS = False
267+
268+
try:
269+
import pyarrow
270+
271+
HAS_PYARROW = True
272+
except ImportError:
273+
HAS_PYARROW = False
274+
275+
276+
@pytest.mark.skipif(not HAS_POLARS, reason="polars not installed")
277+
class TestPolarsInsert:
278+
"""Tests for Polars DataFrame insert support."""
279+
280+
def test_insert_polars_basic(self, schema_insert):
281+
"""Test inserting a Polars DataFrame."""
282+
table = SimpleTable()
283+
df = polars.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"], "score": [1.0, 2.0, 3.0]})
284+
table.insert(df)
285+
assert len(table) == 3
286+
assert set(table.to_arrays("id")) == {1, 2, 3}
287+
288+
def test_insert_polars_with_options(self, schema_insert):
289+
"""Test Polars insert with skip_duplicates and chunk_size."""
290+
table = SimpleTable()
291+
df = polars.DataFrame({"id": [1, 2], "value": ["a", "b"], "score": [1.0, 2.0]})
292+
table.insert(df)
293+
294+
# Insert more with duplicates
295+
df2 = polars.DataFrame({"id": [2, 3, 4], "value": ["b", "c", "d"], "score": [2.0, 3.0, 4.0]})
296+
table.insert(df2, skip_duplicates=True)
297+
assert len(table) == 4
298+
299+
def test_insert_polars_chunk_size(self, schema_insert):
300+
"""Test Polars insert with chunk_size."""
301+
table = SimpleTable()
302+
df = polars.DataFrame(
303+
{"id": list(range(50)), "value": [f"v{i}" for i in range(50)], "score": [float(i) for i in range(50)]}
304+
)
305+
table.insert(df, chunk_size=10)
306+
assert len(table) == 50
307+
308+
def test_insert_polars_roundtrip(self, schema_insert):
309+
"""Test roundtrip: to_polars() -> insert()."""
310+
table = SimpleTable()
311+
table.insert([{"id": i, "value": f"val{i}", "score": float(i)} for i in range(3)])
312+
313+
# Fetch as Polars
314+
df = table.to_polars()
315+
assert isinstance(df, polars.DataFrame)
316+
317+
# Clear and re-insert
318+
with dj.config.override(safemode=False):
319+
table.delete()
320+
321+
table.insert(df)
322+
assert len(table) == 3
323+
324+
325+
@pytest.mark.skipif(not HAS_PYARROW, reason="pyarrow not installed")
326+
class TestArrowInsert:
327+
"""Tests for PyArrow Table insert support."""
328+
329+
def test_insert_arrow_basic(self, schema_insert):
330+
"""Test inserting a PyArrow Table."""
331+
table = SimpleTable()
332+
arrow_table = pyarrow.table({"id": [1, 2, 3], "value": ["a", "b", "c"], "score": [1.0, 2.0, 3.0]})
333+
table.insert(arrow_table)
334+
assert len(table) == 3
335+
assert set(table.to_arrays("id")) == {1, 2, 3}
336+
337+
def test_insert_arrow_with_options(self, schema_insert):
338+
"""Test Arrow insert with skip_duplicates."""
339+
table = SimpleTable()
340+
arrow_table = pyarrow.table({"id": [1, 2], "value": ["a", "b"], "score": [1.0, 2.0]})
341+
table.insert(arrow_table)
342+
343+
# Insert more with duplicates
344+
arrow_table2 = pyarrow.table({"id": [2, 3, 4], "value": ["b", "c", "d"], "score": [2.0, 3.0, 4.0]})
345+
table.insert(arrow_table2, skip_duplicates=True)
346+
assert len(table) == 4
347+
348+
def test_insert_arrow_chunk_size(self, schema_insert):
349+
"""Test Arrow insert with chunk_size."""
350+
table = SimpleTable()
351+
arrow_table = pyarrow.table(
352+
{"id": list(range(50)), "value": [f"v{i}" for i in range(50)], "score": [float(i) for i in range(50)]}
353+
)
354+
table.insert(arrow_table, chunk_size=10)
355+
assert len(table) == 50
356+
357+
def test_insert_arrow_roundtrip(self, schema_insert):
358+
"""Test roundtrip: to_arrow() -> insert()."""
359+
table = SimpleTable()
360+
table.insert([{"id": i, "value": f"val{i}", "score": float(i)} for i in range(3)])
361+
362+
# Fetch as Arrow
363+
arrow_table = table.to_arrow()
364+
assert isinstance(arrow_table, pyarrow.Table)
365+
366+
# Clear and re-insert
367+
with dj.config.override(safemode=False):
368+
table.delete()
369+
370+
table.insert(arrow_table)
371+
assert len(table) == 3
372+
373+
261374
class TestDeprecationWarning:
262375
"""Tests for positional insert deprecation warning."""
263376

0 commit comments

Comments
 (0)