Commit 2bcca48
fix(parquetio): handle missing nullable fields in row conversion (#35948)
* fix(parquetio): handle missing nullable fields in row conversion

  Add null value handling when converting rows to Arrow tables for nullable
  fields that are missing from the input data. This fixes a KeyError when
  writing to Parquet with missing nullable fields, addressing issue #35791.

* fix lint
1 parent 62cbf83 commit 2bcca48
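
For context, a minimal reproduction sketch of the bug (not part of the commit; the pipeline, schema, and output path here are illustrative): before this fix, a row that omits a nullable field entirely, rather than carrying an explicit None, raised a KeyError during the row-to-Arrow conversion inside WriteToParquet.

    import pyarrow as pa
    import apache_beam as beam
    from apache_beam.io.parquetio import WriteToParquet

    schema = pa.schema([
        pa.field('id', pa.int64(), nullable=True),
        pa.field('email', pa.string(), nullable=True),
    ])

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([{'id': 1}])  # 'email' is absent, not an explicit None
          # Before this commit: KeyError: 'email' during row buffering.
          | WriteToParquet(
              '/tmp/repro', schema, num_shards=1, shard_name_template=''))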

File tree

2 files changed: +78 -4 lines

  sdks/python/apache_beam/io/parquetio.py
  sdks/python/apache_beam/io/parquetio_test.py


sdks/python/apache_beam/io/parquetio.py

Lines changed: 6 additions & 1 deletion
@@ -119,7 +119,12 @@ def process(self, row, w=DoFn.WindowParam, pane=DoFn.PaneInfoParam):
 
     # reorder the data in columnar format.
     for i, n in enumerate(self._schema.names):
-      self._buffer[i].append(row[n])
+      # Handle missing nullable fields by using None as default value
+      field = self._schema.field(i)
+      if field.nullable and n not in row:
+        self._buffer[i].append(None)
+      else:
+        self._buffer[i].append(row[n])
 
   def finish_bundle(self):
     if len(self._buffer[0]) > 0:
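
Note the guard: only fields the schema marks nullable get the None default, so a missing non-nullable field still raises KeyError and schema enforcement is preserved. A standalone sketch of the patched loop (the list-of-lists buffer and the field names are illustrative, not the DoFn's actual state):

    import pyarrow as pa

    schema = pa.schema([
        pa.field('id', pa.int64(), nullable=False),
        pa.field('email', pa.string(), nullable=True),
    ])
    buffer = [[] for _ in schema.names]

    def buffer_row(row):
      # Mirrors the patched loop: absent nullable fields buffer as None.
      for i, n in enumerate(schema.names):
        field = schema.field(i)
        if field.nullable and n not in row:
          buffer[i].append(None)
        else:
          # Still KeyError if a non-nullable field is absent.
          buffer[i].append(row[n])

    buffer_row({'id': 1})  # 'email' missing -> None is buffered
    print(buffer)          # [[1], [None]]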

sdks/python/apache_beam/io/parquetio_test.py

Lines changed: 72 additions & 3 deletions
@@ -59,12 +59,11 @@
 try:
   import pyarrow as pa
   import pyarrow.parquet as pq
+  ARROW_MAJOR_VERSION, _, _ = map(int, pa.__version__.split('.'))
 except ImportError:
   pa = None
-  pl = None
   pq = None
-
-ARROW_MAJOR_VERSION, _, _ = map(int, pa.__version__.split('.'))
+  ARROW_MAJOR_VERSION = 0
 
 
 @unittest.skipIf(pa is None, "PyArrow is not installed.")
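
This hunk fixes a secondary problem in the old layout: when pyarrow is not installed, the except branch leaves pa = None, and the version parse then ran unconditionally at module scope, so importing the test module crashed before unittest.skipIf could skip anything. Sketched import-time behavior of the old code (assumed, matching the deleted lines):

    pa = None  # what the old except ImportError branch left behind
    ARROW_MAJOR_VERSION, _, _ = map(int, pa.__version__.split('.'))
    # AttributeError: 'NoneType' object has no attribute '__version__'

Parsing the version inside the try block and defaulting ARROW_MAJOR_VERSION to 0 on ImportError keeps the module importable; the deleted pl = None appears to be a stray binding with no other use in this diff.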
@@ -422,6 +421,76 @@ def test_schema_read_write(self):
             | Map(stable_repr))
         assert_that(readback, equal_to([stable_repr(r) for r in rows]))
 
+  def test_write_with_nullable_fields_missing_data(self):
+    """Test WriteToParquet with nullable fields where some fields are missing.
+
+    This test addresses the bug reported in:
+    https://github.com/apache/beam/issues/35791
+    where WriteToParquet fails with a KeyError if any nullable
+    field is missing in the data.
+    """
+    # Define PyArrow schema with all fields nullable
+    schema = pa.schema([
+        pa.field("id", pa.int64(), nullable=True),
+        pa.field("name", pa.string(), nullable=True),
+        pa.field("age", pa.int64(), nullable=True),
+        pa.field("email", pa.string(), nullable=True),
+    ])
+
+    # Sample data with missing nullable fields
+    data = [
+        {
+            'id': 1, 'name': 'Alice', 'age': 30
+        },  # missing 'email'
+        {
+            'id': 2, 'name': 'Bob', 'age': 25, 'email': 'bob@example.com'
+        },  # all fields present
+        {
+            'id': 3, 'name': 'Charlie', 'age': None, 'email': None
+        },  # explicit None values
+        {
+            'id': 4, 'name': 'David'
+        },  # missing 'age' and 'email'
+    ]
+
+    with TemporaryDirectory() as tmp_dirname:
+      path = os.path.join(tmp_dirname, 'nullable_test')
+
+      # Write data with missing nullable fields - this should not raise KeyError
+      with TestPipeline() as p:
+        _ = (
+            p
+            | Create(data)
+            | WriteToParquet(
+                path, schema, num_shards=1, shard_name_template=''))
+
+      # Read back and verify the data
+      with TestPipeline() as p:
+        readback = (
+            p
+            | ReadFromParquet(path + '*')
+            | Map(json.dumps, sort_keys=True))
+
+        # Expected data should have None for missing nullable fields
+        expected_data = [
+            {
+                'id': 1, 'name': 'Alice', 'age': 30, 'email': None
+            },
+            {
+                'id': 2, 'name': 'Bob', 'age': 25, 'email': 'bob@example.com'
+            },
+            {
+                'id': 3, 'name': 'Charlie', 'age': None, 'email': None
+            },
+            {
+                'id': 4, 'name': 'David', 'age': None, 'email': None
+            },
+        ]
+
+        assert_that(
+            readback,
+            equal_to([json.dumps(r, sort_keys=True) for r in expected_data]))
+
   def test_batched_read(self):
     with TemporaryDirectory() as tmp_dirname:
       path = os.path.join(tmp_dirname + "tmp_filename")
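
As a quick sanity check of the expected readback, nulls survive a Parquet round trip at the pyarrow level as well; a minimal pyarrow-only sketch (the file path is illustrative), using the 'age' values the test expects back:

    import pyarrow as pa
    import pyarrow.parquet as pq

    # Column with explicit nulls, as buffered for the four test rows.
    table = pa.table({'age': pa.array([30, 25, None, None], type=pa.int64())})
    pq.write_table(table, '/tmp/ages.parquet')
    print(pq.read_table('/tmp/ages.parquet').to_pydict())
    # {'age': [30, 25, None, None]}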
