Skip to content

Commit e6d5434

Browse files
committed
fix: build explicit PyArrow schema in Neo4jGraphParquetFormatter to prevent silent embedding column drops
1 parent 3d092f6 commit e6d5434

File tree

2 files changed

+266
-3
lines changed

2 files changed

+266
-3
lines changed

src/neo4j_graphrag/experimental/components/parquet_formatter.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,13 +481,38 @@ def format_parquet(
481481
import pyarrow.parquet as pq
482482

483483
self._normalize_column_types(rows)
484-
table = pa.Table.from_pylist(rows)
485-
# Write to BytesIO buffer
484+
485+
# Build an explicit schema from the union of all keys to avoid
486+
# silent column drops when the first row lacks some keys (e.g. embeddings).
487+
# Dict preserves first-seen insertion order for deterministic column ordering.
488+
all_keys: dict[str, None] = {k: None for row in rows for k in row}
489+
# First non-null value for each key, used to infer the column type.
490+
sample: dict[str, Any] = {}
491+
for row in rows:
492+
for k, v in row.items():
493+
if k not in sample and v is not None:
494+
sample[k] = v
495+
496+
fields: list[Any] = []
497+
for k in all_keys:
498+
if k in sample:
499+
t: Any = pa.infer_type([sample[k]])
500+
if pa.types.is_list(t) and pa.types.is_floating(t.value_type):
501+
t = pa.list_(pa.float32())
502+
else:
503+
t = pa.null()
504+
fields.append(pa.field(k, t))
505+
506+
schema = pa.schema(fields) if fields else None
507+
table = pa.Table.from_pylist(rows, schema=schema)
508+
486509
buffer = BytesIO()
487510
pq.write_table(table, buffer)
488511
buffer.seek(0)
489512
return buffer.read(), table.schema
490-
except (ValueError, TypeError) as e:
513+
except ImportError:
514+
raise
515+
except (ValueError, TypeError, pa.ArrowInvalid) as e:
491516
raise ValueError(
492517
f"Failed to create Parquet table for {entity_name}: {e}"
493518
) from e

tests/unit/experimental/components/test_kg_writer.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import tempfile
18+
from io import BytesIO
1819
from pathlib import Path
1920
from typing import Any
2021
from unittest import mock
@@ -767,3 +768,240 @@ async def test_parquet_writer_mixed_property_types() -> None:
767768
# Both ages should have been coerced to str
768769
ages = {v.as_py() for v in table.column("age")}
769770
assert ages == {"45", "30"}
771+
772+
773+
# ---------------------------------------------------------------------------
774+
# Regression tests: node embedding column must be present regardless of row order
775+
# ---------------------------------------------------------------------------
776+
777+
778+
@pytest.mark.parametrize(
    "failed_first",
    [
        pytest.param(True, id="failed_batch_first"),
        pytest.param(False, id="succeeded_batch_first"),
    ],
)
def test_node_embedding_column_present_regardless_of_row_order(
    failed_first: bool,
) -> None:
    """The node embedding column must appear in the Parquet output for either row order.

    Regression test: when nodes from a failed batch (no embedding key) preceded
    nodes from a succeeded batch, PyArrow's schema inference dropped the
    embedding column entirely.
    """
    pytest.importorskip("pyarrow")
    import pyarrow as pa
    import pyarrow.parquet as pq

    formatter = Neo4jGraphParquetFormatter()

    # Two batches of the same node label: one row lacks the embedding key
    # (as if embedding_properties was empty), the other carries it.
    row_without_embedding: dict[str, Any] = {
        "__id__": "node-1",
        "name": "Alice",
        "labels": ["Person", "__Entity__"],
    }
    row_with_embedding: dict[str, Any] = {
        "__id__": "node-2",
        "name": "Bob",
        "labels": ["Person", "__Entity__"],
        "embedding": [0.1, 0.2, 0.3],
    }

    if failed_first:
        rows = [row_without_embedding, row_with_embedding]
    else:
        rows = [row_with_embedding, row_without_embedding]

    parquet_bytes, schema = formatter.format_parquet(rows, "node label 'Person'")

    # The embedding column must always be present in the schema.
    assert "embedding" in schema.names, (
        f"'embedding' column missing from schema when failed_first={failed_first}. "
        f"Schema columns: {schema.names}"
    )

    # Round-trip the buffer and inspect null handling.
    table = pq.read_table(BytesIO(parquet_bytes))
    assert "embedding" in table.column_names

    by_id = {}
    for record in table.to_pylist():
        by_id[record["__id__"]] = record
    assert (
        by_id["node-1"]["embedding"] is None
    ), "Row without embedding should have null value in the embedding column"
    assert (
        by_id["node-2"]["embedding"] is not None
    ), "Row with embedding should have a non-null value in the embedding column"

    # The embedding field must be a (variable or fixed-size) list of float32;
    # because node-1 is null, the formatter must fall back to list_(float32).
    emb_type = schema.field("embedding").type
    assert pa.types.is_list(emb_type) or pa.types.is_fixed_size_list(
        emb_type
    ), f"Unexpected embedding field type: {emb_type}"
    assert (
        emb_type.value_type == pa.float32()
    ), f"Embedding value type should be float32, got {emb_type.value_type}"
850+
851+
852+
# ---------------------------------------------------------------------------
853+
# Regression tests: relationship embedding column must be present regardless of row order
854+
# ---------------------------------------------------------------------------
855+
856+
857+
@pytest.mark.parametrize(
    "failed_first",
    [
        pytest.param(True, id="failed_batch_first"),
        pytest.param(False, id="succeeded_batch_first"),
    ],
)
def test_relationship_embedding_column_present_regardless_of_row_order(
    failed_first: bool,
) -> None:
    """The relationship embedding column must appear in the Parquet output for either row order.

    Regression test: when relationships from a failed batch (no embedding key)
    preceded relationships from a succeeded batch, PyArrow's schema inference
    dropped the embedding column entirely.
    """
    pytest.importorskip("pyarrow")
    import pyarrow as pa
    import pyarrow.parquet as pq

    formatter = Neo4jGraphParquetFormatter()

    # Two batches of the same relationship type: one row lacks the embedding
    # key (as if embedding_properties was empty), the other carries it.
    rel_without_embedding: dict[str, Any] = {
        "from": "node-1",
        "to": "node-2",
        "from_label": "Person",
        "to_label": "Person",
        "type": "KNOWS",
        "since": "2020",
    }
    rel_with_embedding: dict[str, Any] = {
        "from": "node-3",
        "to": "node-4",
        "from_label": "Person",
        "to_label": "Person",
        "type": "KNOWS",
        "since": "2021",
        "embedding": [0.1, 0.2, 0.3],
    }

    if failed_first:
        rows = [rel_without_embedding, rel_with_embedding]
    else:
        rows = [rel_with_embedding, rel_without_embedding]

    parquet_bytes, schema = formatter.format_parquet(
        rows, "relationship 'Person_KNOWS_Person'"
    )

    # The embedding column must always be present in the schema.
    assert "embedding" in schema.names, (
        f"'embedding' column missing from relationship schema when failed_first={failed_first}. "
        f"Schema columns: {schema.names}"
    )

    # Round-trip the buffer and inspect null handling.
    table = pq.read_table(BytesIO(parquet_bytes))
    assert "embedding" in table.column_names

    by_source = {}
    for record in table.to_pylist():
        by_source[record["from"]] = record
    assert (
        by_source["node-1"]["embedding"] is None
    ), "Relationship row without embedding should have null value in the embedding column"
    assert (
        by_source["node-3"]["embedding"] is not None
    ), "Relationship row with embedding should have a non-null value in the embedding column"

    # The embedding field must be a (variable or fixed-size) list of float32;
    # because the failed row is null, the formatter must fall back to list_(float32).
    emb_type = schema.field("embedding").type
    assert pa.types.is_list(emb_type) or pa.types.is_fixed_size_list(
        emb_type
    ), f"Unexpected relationship embedding field type: {emb_type}"
    assert (
        emb_type.value_type == pa.float32()
    ), f"Relationship embedding value type should be float32, got {emb_type.value_type}"
937+
938+
939+
# ---------------------------------------------------------------------------
940+
# Degenerate case: all rows lack the embedding key (all-null column path)
941+
# ---------------------------------------------------------------------------
942+
943+
944+
def test_format_parquet_all_rows_missing_embedding_does_not_crash() -> None:
    """format_parquet must succeed when no row carries an embedding key at all.

    With every row lacking the key, the formatter's pa.null() fallback path is
    exercised. The resulting table must round-trip through Parquet and contain
    only the columns that were actually present in the input rows.
    """
    pytest.importorskip("pyarrow")
    import pyarrow.parquet as pq

    formatter = Neo4jGraphParquetFormatter()

    node_rows: list[dict[str, Any]] = [
        {"__id__": "node-1", "name": "Alice", "labels": ["Person"]},
        {"__id__": "node-2", "name": "Bob", "labels": ["Person"]},
    ]

    parquet_bytes, schema = formatter.format_parquet(node_rows, "node label 'Person'")

    assert (
        "embedding" not in schema.names
    ), "Embedding column should not appear when no row carries an embedding key"

    # Round-trip and confirm row count plus the exact column set.
    table = pq.read_table(BytesIO(parquet_bytes))
    assert table.num_rows == 2
    assert set(table.column_names) == {"__id__", "name", "labels"}
970+
971+
972+
# ---------------------------------------------------------------------------
973+
# Edge case: all rows have an empty list for the embedding key (all-null path)
974+
# ---------------------------------------------------------------------------
975+
976+
977+
def test_format_parquet_all_rows_empty_list_embedding_does_not_crash() -> None:
    """format_parquet must succeed when every row's embedding is an empty list.

    Empty lists are falsy but not None, so they pass the formatter's
    `v is not None` sampling guard, and pa.infer_type([[]]) yields list<null>.
    This verifies that such a table survives a Parquet round-trip and that
    every embedding value comes back as an empty list (or null).
    """
    pytest.importorskip("pyarrow")
    import pyarrow.parquet as pq

    formatter = Neo4jGraphParquetFormatter()

    node_rows: list[dict[str, Any]] = [
        {"__id__": "node-1", "name": "Alice", "labels": ["Person"], "embedding": []},
        {"__id__": "node-2", "name": "Bob", "labels": ["Person"], "embedding": []},
    ]

    parquet_bytes, schema = formatter.format_parquet(node_rows, "node label 'Person'")

    assert (
        "embedding" in schema.names
    ), "Embedding column should be present even when all rows have an empty list"

    # Round-trip and check every materialized embedding value.
    table = pq.read_table(BytesIO(parquet_bytes))
    assert table.num_rows == 2
    assert "embedding" in table.column_names
    for row in table.to_pylist():
        assert (
            row["embedding"] == [] or row["embedding"] is None
        ), f"Expected empty list or null for embedding, got {row['embedding']}"

0 commit comments

Comments (0)