Skip to content

Commit 66f143e

Browse files
Merge origin/claude/modern-fetch-api: Add Polars and PyArrow insert support
Resolved conflict in table.py docstring by combining NumPy-style format with new Polars/PyArrow type support. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
2 parents 1d2832b + 4164348 commit 66f143e

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed

src/datajoint/table.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -741,6 +741,8 @@ def insert(
741741
- iterable of dicts: ``[{"attr": value, ...}, ...]``
742742
- iterable of numpy.void: Records from a structured array
743743
- pandas.DataFrame: Each row becomes a table row
744+
- polars.DataFrame: Each row becomes a table row
745+
- pyarrow.Table: Each row becomes a table row
744746
- QueryExpression: Results of a query (insert from select)
745747
- pathlib.Path: Path to a CSV file
746748
@@ -781,6 +783,14 @@ def insert(
781783
# frames with more advanced indices should be prepared by user.
782784
rows = rows.reset_index(drop=len(rows.index.names) == 1 and not rows.index.names[0]).to_records(index=False)
783785

786+
# Polars DataFrame -> list of dicts (soft dependency, check by type name)
787+
if type(rows).__module__.startswith("polars") and type(rows).__name__ == "DataFrame":
788+
rows = rows.to_dicts()
789+
790+
# PyArrow Table -> list of dicts (soft dependency, check by type name)
791+
if type(rows).__module__.startswith("pyarrow") and type(rows).__name__ == "Table":
792+
rows = rows.to_pylist()
793+
784794
if isinstance(rows, Path):
785795
with open(rows, newline="") as data_file:
786796
rows = list(csv.DictReader(data_file, delimiter=","))

tests/integration/test_insert.py

Lines changed: 113 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -258,6 +258,119 @@ def test_insert_dataframe_with_chunk_size(self, schema_insert):
258258
assert len(table) == 100
259259

260260

261+
try:
262+
import polars
263+
264+
HAS_POLARS = True
265+
except ImportError:
266+
HAS_POLARS = False
267+
268+
try:
269+
import pyarrow
270+
271+
HAS_PYARROW = True
272+
except ImportError:
273+
HAS_PYARROW = False
274+
275+
276+
@pytest.mark.skipif(not HAS_POLARS, reason="polars not installed")
class TestPolarsInsert:
    """Exercise ``insert()`` when the rows arrive as a ``polars.DataFrame``."""

    def test_insert_polars_basic(self, schema_insert):
        """A three-row Polars frame ends up as three rows with ids 1, 2, 3."""
        target = SimpleTable()
        frame = polars.DataFrame(
            {
                "id": [1, 2, 3],
                "value": ["a", "b", "c"],
                "score": [1.0, 2.0, 3.0],
            }
        )

        target.insert(frame)

        assert len(target) == 3
        assert set(target.to_arrays("id")) == {1, 2, 3}

    def test_insert_polars_with_options(self, schema_insert):
        """A second Polars insert with ``skip_duplicates=True`` adds only new ids."""
        target = SimpleTable()
        initial = polars.DataFrame(
            {"id": [1, 2], "value": ["a", "b"], "score": [1.0, 2.0]}
        )
        target.insert(initial)

        # id 2 is already present; ids 3 and 4 are new.
        overlapping = polars.DataFrame(
            {
                "id": [2, 3, 4],
                "value": ["b", "c", "d"],
                "score": [2.0, 3.0, 4.0],
            }
        )
        target.insert(overlapping, skip_duplicates=True)

        assert len(target) == 4

    def test_insert_polars_chunk_size(self, schema_insert):
        """Fifty Polars rows inserted with ``chunk_size=10`` all arrive."""
        target = SimpleTable()
        ids = range(50)
        frame = polars.DataFrame(
            {
                "id": list(ids),
                "value": [f"v{i}" for i in ids],
                "score": [float(i) for i in ids],
            }
        )

        target.insert(frame, chunk_size=10)

        assert len(target) == 50

    def test_insert_polars_roundtrip(self, schema_insert):
        """Rows fetched with ``to_polars()`` can be fed straight back to ``insert()``."""
        target = SimpleTable()
        seed_rows = [
            {"id": i, "value": f"val{i}", "score": float(i)} for i in range(3)
        ]
        target.insert(seed_rows)

        # Pull the contents out as a Polars frame.
        fetched = target.to_polars()
        assert isinstance(fetched, polars.DataFrame)

        # Empty the table (safemode is overridden to False for the delete),
        # then re-insert what was fetched.
        with dj.config.override(safemode=False):
            target.delete()

        target.insert(fetched)
        assert len(target) == 3
323+
324+
325+
@pytest.mark.skipif(not HAS_PYARROW, reason="pyarrow not installed")
class TestArrowInsert:
    """Exercise ``insert()`` when the rows arrive as a ``pyarrow.Table``."""

    def test_insert_arrow_basic(self, schema_insert):
        """A three-row Arrow table ends up as three rows with ids 1, 2, 3."""
        target = SimpleTable()
        batch = pyarrow.table(
            {
                "id": [1, 2, 3],
                "value": ["a", "b", "c"],
                "score": [1.0, 2.0, 3.0],
            }
        )

        target.insert(batch)

        assert len(target) == 3
        assert set(target.to_arrays("id")) == {1, 2, 3}

    def test_insert_arrow_with_options(self, schema_insert):
        """A second Arrow insert with ``skip_duplicates=True`` adds only new ids."""
        target = SimpleTable()
        initial = pyarrow.table(
            {"id": [1, 2], "value": ["a", "b"], "score": [1.0, 2.0]}
        )
        target.insert(initial)

        # id 2 is already present; ids 3 and 4 are new.
        overlapping = pyarrow.table(
            {
                "id": [2, 3, 4],
                "value": ["b", "c", "d"],
                "score": [2.0, 3.0, 4.0],
            }
        )
        target.insert(overlapping, skip_duplicates=True)

        assert len(target) == 4

    def test_insert_arrow_chunk_size(self, schema_insert):
        """Fifty Arrow rows inserted with ``chunk_size=10`` all arrive."""
        target = SimpleTable()
        ids = range(50)
        batch = pyarrow.table(
            {
                "id": list(ids),
                "value": [f"v{i}" for i in ids],
                "score": [float(i) for i in ids],
            }
        )

        target.insert(batch, chunk_size=10)

        assert len(target) == 50

    def test_insert_arrow_roundtrip(self, schema_insert):
        """Rows fetched with ``to_arrow()`` can be fed straight back to ``insert()``."""
        target = SimpleTable()
        seed_rows = [
            {"id": i, "value": f"val{i}", "score": float(i)} for i in range(3)
        ]
        target.insert(seed_rows)

        # Pull the contents out as an Arrow table.
        fetched = target.to_arrow()
        assert isinstance(fetched, pyarrow.Table)

        # Empty the table (safemode is overridden to False for the delete),
        # then re-insert what was fetched.
        with dj.config.override(safemode=False):
            target.delete()

        target.insert(fetched)
        assert len(target) == 3
372+
373+
261374
class TestDeprecationWarning:
262375
"""Tests for positional insert deprecation warning."""
263376

0 commit comments

Comments
 (0)