Skip to content

Commit 7862e2f

Browse files
zxqfd555 authored and
Manul from Pathway committed
preserve properties when inheriting schemas (#9435)
GitOrigin-RevId: 9e5f19481e5558041d2f09fa2891206927e78b33
1 parent 76319c4 commit 7862e2f

File tree

3 files changed

+86
-5
lines changed

3 files changed

+86
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
77

88
### Fixed
99
- Endpoints created by `pw.io.http.rest_connector` now accept requests both with and without a trailing slash. For example, `/endpoint/` and `/endpoint` are now treated equivalently.
10+
- Schemas that inherit from other schemas now automatically preserve all properties from their parent schemas.
1011

1112
## [0.26.4] - 2025-10-16
1213

python/pathway/internals/schema.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,21 @@ def schema_add(*schemas: type[Schema]) -> type[Schema]:
178178

179179

180180
def _create_column_definitions(
181-
schema: SchemaMetaclass, schema_properties: SchemaProperties
181+
schema: SchemaMetaclass,
182+
bases: tuple[type],
183+
schema_properties: SchemaProperties,
182184
) -> dict[str, ColumnSchema]:
183185
localns = locals()
184186
# Update locals to handle recursive Schema definitions
185187
localns[schema.__name__] = schema
186188
annotations = get_type_hints(schema, localns=localns)
187189
fields = _cls_fields(schema)
190+
for base in bases:
191+
if not isinstance(base, SchemaMetaclass):
192+
continue
193+
for column_name, column_schema in base.__columns__.items():
194+
if column_name not in fields:
195+
fields[column_name] = column_schema.to_definition()
188196

189197
columns = {}
190198

@@ -274,15 +282,17 @@ class SchemaMetaclass(type):
274282
@trace.trace_user_frame
275283
def __init__(
276284
self,
277-
*args,
285+
name: str,
286+
bases: tuple[type],
287+
namespace: dict[str, Any],
278288
append_only: bool | None = None,
279289
id_dtype: dt.DType = dt.ANY_POINTER,
280290
id_append_only: bool | None = None,
281291
**kwargs,
282292
) -> None:
283-
super().__init__(*args, **kwargs)
293+
super().__init__(name, bases, namespace, **kwargs)
284294
schema_properties = SchemaProperties(append_only=append_only)
285-
self.__columns__ = _create_column_definitions(self, schema_properties)
295+
self.__columns__ = _create_column_definitions(self, bases, schema_properties)
286296
pk_dtypes = [col.dtype for col in self.__columns__.values() if col.primary_key]
287297
if len(pk_dtypes) > 0:
288298
derived_type = dt.Pointer(*pk_dtypes)
@@ -588,6 +598,7 @@ def assert_matches_schema(
588598
allow_superset: bool = True,
589599
ignore_primary_keys: bool = True,
590600
allow_subtype: bool = True,
601+
ignore_properties: bool = True,
591602
) -> None:
592603
self_dict = self._dtypes()
593604
other_dict = other._dtypes()
@@ -619,6 +630,12 @@ def assert_matches_schema(
619630
f"primary keys in the schemas do not match - they are {self.primary_key_columns()} in {self.__name__}",
620631
f" and {other.primary_key_columns()} in {other.__name__}",
621632
)
633+
if not ignore_properties:
634+
self_columns = self.columns()
635+
other_columns = other.columns()
636+
for column_name, column_schema in self_columns.items():
637+
other_column_schema = other_columns[column_name]
638+
assert column_schema == other_column_schema
622639

623640

624641
def _schema_builder(

python/pathway/tests/test_schema.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
from typing import Any
1111

1212
import numpy as np
13+
import pandas as pd
1314
import pytest
1415
from dateutil import tz
1516

1617
import pathway as pw
1718
from pathway.internals.schema import Schema
18-
from pathway.tests.utils import write_csv
19+
from pathway.tests.utils import write_csv, write_lines
1920

2021

2122
def assert_same_schema(left: type[Schema], right: type[Schema]):
@@ -433,3 +434,65 @@ class InputSchema(pw.Schema):
433434
DeserializedSchema.to_json_serializable_dict(), sort_keys=True
434435
)
435436
assert serialized_schema_json == roundtrip_schema_json
437+
438+
439+
@pytest.mark.parametrize("option", ["inheritance_1", "inheritance_2", "disjunction"])
def test_schemas_composition(option, tmp_path):
    """Check that composed schemas keep the column properties of their parts.

    ``InputSchema`` is built from ``BaseSchema_1`` (and ``BaseSchema_2``) in
    one of three ways, selected by ``option``:

    * ``"inheritance_1"`` - subclass ``BaseSchema_1`` and declare ``value2``;
    * ``"inheritance_2"`` - subclass both base schemas with an empty body;
    * ``"disjunction"`` - combine the two base schemas with ``|``.

    The result is compared against a hand-written ``ExpectedSchema`` with
    every relaxation of ``assert_matches_schema`` switched off, and is then
    used to read a JSON Lines file in which one row omits ``value`` so that
    the declared default value has to be applied.

    Args:
        option: Name of the composition strategy under test.
        tmp_path: pytest-provided temporary directory for the input file.

    Raises:
        ValueError: If ``option`` is not one of the three known strategies.
    """
    input_path = tmp_path / "input.jsonl"

    class BaseSchema_1(pw.Schema):
        key: int = pw.column_definition(primary_key=True)
        value: str = pw.column_definition(
            default_value="default_value",
            description="test description",
            example="some example",
        )

    class BaseSchema_2(pw.Schema):
        value2: bool

    if option == "inheritance_1":

        class InputSchema(BaseSchema_1):
            value2: bool

    elif option == "inheritance_2":

        class InputSchema(BaseSchema_1, BaseSchema_2):
            pass

    elif option == "disjunction":
        InputSchema = BaseSchema_1 | BaseSchema_2
    else:
        # The f-prefix is required so the offending value is interpolated
        # into the message instead of the literal text "{option}".
        raise ValueError(f"unexpected option: {option}")

    class ExpectedSchema(pw.Schema):
        key: int = pw.column_definition(primary_key=True)
        value: str = pw.column_definition(
            default_value="default_value",
            description="test description",
            example="some example",
        )
        value2: bool

    # Strictest comparison available: dtypes, column set, primary keys and
    # per-column properties all have to agree.
    InputSchema.assert_matches_schema(
        ExpectedSchema,
        allow_subtype=False,
        allow_superset=False,
        ignore_primary_keys=False,
        ignore_properties=False,
    )

    # The second row has no "value" field on purpose: it must come back as
    # the default declared on BaseSchema_1. The payload lines are kept
    # flush-left, exactly as they appear in the diff view.
    input_data = """
{"key": 0, "value": "hello", "value2": true}
{"key": 1, "value2": false}
"""
    write_lines(input_path, input_data)

    table = pw.io.jsonlines.read(input_path, mode="static", schema=InputSchema)

    output_table = pw.debug.table_to_pandas(table)
    row_0 = output_table.loc[output_table["key"] == 0, ["value", "value2"]].iloc[0]
    assert (row_0 == pd.Series({"value": "hello", "value2": True})).all()
    row_1 = output_table.loc[output_table["key"] == 1, ["value", "value2"]].iloc[0]
    assert (row_1 == pd.Series({"value": "default_value", "value2": False})).all()

0 commit comments

Comments
 (0)