|
25 | 25 | FilenameCollisionHandler, |
26 | 26 | ) |
27 | 27 | from neo4j_graphrag.experimental.components.parquet_formatter import ( |
| 28 | + Neo4jGraphParquetFormatter, |
28 | 29 | sanitize_parquet_filestem, |
29 | 30 | ) |
30 | 31 | from neo4j_graphrag.experimental.components.kg_writer import ( |
@@ -697,3 +698,72 @@ async def test_parquet_writer_run_empty_graph() -> None: |
697 | 698 | assert stats["nodes_per_label"] == {} |
698 | 699 | assert stats["rel_per_type"] == {} |
699 | 700 | assert result.metadata["files"] == [] |
| 701 | + |
| 702 | + |
| 703 | +# --------------------------------------------------------------------------- |
| 704 | +# Neo4jGraphParquetFormatter._normalize_column_types |
| 705 | +# --------------------------------------------------------------------------- |
| 706 | + |
| 707 | + |
| 708 | +def test_normalize_column_types_single_row() -> None: |
| 709 | + rows = [{"age": 30, "name": "Alice"}] |
| 710 | + Neo4jGraphParquetFormatter._normalize_column_types(rows) |
| 711 | + assert rows == [{"age": 30, "name": "Alice"}] |
| 712 | + |
| 713 | + |
| 714 | +def test_normalize_column_types_homogeneous() -> None: |
| 715 | + rows = [{"age": 30}, {"age": 25}] |
| 716 | + Neo4jGraphParquetFormatter._normalize_column_types(rows) |
| 717 | + assert rows == [{"age": 30}, {"age": 25}] |
| 718 | + |
| 719 | + |
| 720 | +def test_normalize_column_types_mixed_str_int() -> None: |
| 721 | + rows: list[dict[str, Any]] = [{"age": "45"}, {"age": 30}] |
| 722 | + Neo4jGraphParquetFormatter._normalize_column_types(rows) |
| 723 | + assert rows == [{"age": "45"}, {"age": "30"}] |
| 724 | + |
| 725 | + |
| 726 | +def test_normalize_column_types_mixed_int_float() -> None: |
| 727 | + rows: list[dict[str, Any]] = [{"score": 3}, {"score": 3.5}] |
| 728 | + Neo4jGraphParquetFormatter._normalize_column_types(rows) |
| 729 | + assert rows == [{"score": 3.0}, {"score": 3.5}] |
| 730 | + |
| 731 | + |
| 732 | +def test_normalize_column_types_none_ignored() -> None: |
| 733 | + """None values should not influence type detection.""" |
| 734 | + rows: list[dict[str, Any]] = [{"age": None}, {"age": 30}] |
| 735 | + Neo4jGraphParquetFormatter._normalize_column_types(rows) |
| 736 | + assert rows == [{"age": None}, {"age": 30}] |
| 737 | + |
| 738 | + |
| 739 | +@pytest.mark.asyncio |
| 740 | +async def test_parquet_writer_mixed_property_types() -> None: |
| 741 | + """ParquetWriter succeeds when nodes of the same label have mixed property types.""" |
| 742 | + pytest.importorskip("pyarrow") |
| 743 | + import pyarrow.parquet as pq |
| 744 | + |
| 745 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 746 | + out = Path(tmpdir) |
| 747 | + dest = _LocalParquetDestination(out) |
| 748 | + writer = ParquetWriter( |
| 749 | + nodes_dest=dest, |
| 750 | + relationships_dest=dest, |
| 751 | + collision_handler=FilenameCollisionHandler(), |
| 752 | + ) |
| 753 | + |
| 754 | + node1 = Neo4jNode( |
| 755 | + id="p1", label="Patient", properties={"name": "John", "age": "45"} |
| 756 | + ) |
| 757 | + node2 = Neo4jNode( |
| 758 | + id="p2", label="Patient", properties={"name": "Jane", "age": 30} |
| 759 | + ) |
| 760 | + graph = Neo4jGraph(nodes=[node1, node2], relationships=[]) |
| 761 | + |
| 762 | + result = await writer.run(graph=graph) |
| 763 | + |
| 764 | + assert result.status == "SUCCESS" |
| 765 | + table = pq.read_table(out / "Patient.parquet") |
| 766 | + assert table.num_rows == 2 |
| 767 | + # Both ages should have been coerced to str |
| 768 | + ages = {v.as_py() for v in table.column("age")} |
| 769 | + assert ages == {"45", "30"} |
0 commit comments