Skip to content

Commit eda153c

Browse files
Update with_properties to enable changing existing properties on ColumnSchema (#157)
* Update `ColumnSchema.with_properties` to change properties correctly. * Add tests for `ColumnSchema` equality and update methods `with_*`
1 parent 2c621a2 commit eda153c

File tree

2 files changed

+132
-21
lines changed

2 files changed

+132
-21
lines changed

merlin/schema/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,12 @@ def with_properties(self, properties: dict) -> "ColumnSchema":
172172
raise TypeError("properties must be in dict format, key: value")
173173

174174
# Using new dictionary to avoid passing old ref to new schema
175-
properties.update(self.properties)
175+
new_properties = {**self.properties, **properties}
176176

177177
return ColumnSchema(
178178
self.name,
179179
tags=self.tags,
180-
properties=properties,
180+
properties=new_properties,
181181
dtype=self.dtype,
182182
is_list=self.is_list,
183183
is_ragged=self.is_ragged,

tests/unit/schema/test_column_schemas.py

Lines changed: 130 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,136 @@ def test_dtype_column_schema(d_types):
2727
assert column.dtype == d_types
2828

2929

30-
def test_column_schema_meta():
31-
column = ColumnSchema("name", tags=["tag-1"], properties={"p1": "prop-1"})
32-
33-
assert column.name == "name"
34-
assert "tag-1" in column.tags
35-
assert column.with_name("a").name == "a"
36-
assert set(column.with_tags("tag-2").tags) == set(["tag-1", "tag-2"])
37-
assert column.with_properties({"p2": "prop-2"}).properties == {
38-
"p1": "prop-1",
39-
"p2": "prop-2",
40-
}
41-
assert column.with_tags("tag-2").properties == {"p1": "prop-1"}
42-
assert set(column.with_properties({"p2": "prop-2"}).tags) == set(["tag-1"])
43-
44-
assert column == ColumnSchema("name", tags=["tag-1"], properties={"p1": "prop-1"})
45-
# should not be the same no properties
46-
assert column != ColumnSchema("name", tags=["tag-1"])
47-
# should not be the same no tags
48-
assert column != ColumnSchema("name", properties={"p1": "prop-1"})
30+
@pytest.mark.parametrize(
31+
["column_schema_a", "column_schema_b"],
32+
[
33+
[ColumnSchema("col"), ColumnSchema("col")],
34+
[ColumnSchema("col_b", tags=["tag-1"]), ColumnSchema("col_b", tags=["tag-1"])],
35+
[
36+
ColumnSchema("col", dtype=numpy.int32, properties={"domain": {"min": 0, "max": 8}}),
37+
ColumnSchema("col", dtype=numpy.int32, properties={"domain": {"min": 0, "max": 8}}),
38+
],
39+
[
40+
ColumnSchema(
41+
"col",
42+
dtype=numpy.float32,
43+
tags=["tag-2", Tags.CONTINUOUS],
44+
properties={"p1": "prop-1"},
45+
),
46+
ColumnSchema(
47+
"col",
48+
dtype=numpy.float32,
49+
tags=["tag-2", Tags.CONTINUOUS],
50+
properties={"p1": "prop-1"},
51+
),
52+
],
53+
],
54+
)
55+
def test_equal(column_schema_a, column_schema_b):
56+
assert column_schema_a == column_schema_b
57+
assert column_schema_a.name == column_schema_b.name
58+
assert column_schema_a.dtype == column_schema_b.dtype
59+
assert column_schema_a.tags == column_schema_b.tags
60+
assert column_schema_a.properties == column_schema_b.properties
61+
62+
63+
@pytest.mark.parametrize(
64+
["column_schema_a", "column_schema_b"],
65+
[
66+
[ColumnSchema("col_a"), ColumnSchema("col_b")],
67+
[ColumnSchema("name"), ColumnSchema("name", tags=["tags-1"])],
68+
[ColumnSchema("name"), ColumnSchema("name", properties={"p1": "prop-1"})],
69+
[
70+
ColumnSchema("name", tags=["tag-1"]),
71+
ColumnSchema("name", properties={"p1": "prop-1"}),
72+
],
73+
[
74+
ColumnSchema("name", tags=["tag-1"], properties={"p1": "prop-1"}),
75+
ColumnSchema("name", properties={"p1": "prop-1"}),
76+
],
77+
],
78+
)
79+
def test_not_equal(column_schema_a, column_schema_b):
80+
assert column_schema_a != column_schema_b
81+
82+
83+
@pytest.mark.parametrize(
84+
["column_schema", "name", "expected_column_schema"],
85+
[
86+
[ColumnSchema("col_a"), "col_b", ColumnSchema("col_b")],
87+
[ColumnSchema("feat", tags=["tag-1"]), "seq", ColumnSchema("seq", tags=["tag-1"])],
88+
[
89+
ColumnSchema(
90+
"feat",
91+
tags=["tag-1"],
92+
dtype=numpy.float32,
93+
properties={"domain": {"min": 0.0, "max": 6.0}},
94+
),
95+
"feat_b",
96+
ColumnSchema(
97+
"feat_b",
98+
tags=["tag-1"],
99+
dtype=numpy.float32,
100+
properties={"domain": {"min": 0.0, "max": 6.0}},
101+
),
102+
],
103+
],
104+
)
105+
def test_with_name(column_schema, name, expected_column_schema):
106+
assert column_schema.with_name(name) == expected_column_schema
107+
108+
109+
@pytest.mark.parametrize(
110+
["column_schema", "tags", "expected_column_schema"],
111+
[
112+
[
113+
ColumnSchema("example", tags=["tag-1"], properties={"p1": "prop-1"}),
114+
"tag-2",
115+
ColumnSchema("example", tags=["tag-1", "tag-2"], properties={"p1": "prop-1"}),
116+
],
117+
[
118+
ColumnSchema("example", tags=["tag-1"], dtype=numpy.float32),
119+
["tag-2", Tags.CONTINUOUS],
120+
ColumnSchema("example", tags=["tag-1", "tag-2", Tags.CONTINUOUS], dtype=numpy.float32),
121+
],
122+
],
123+
)
124+
def test_with_tags(column_schema, tags, expected_column_schema):
125+
assert column_schema.with_tags(tags) == expected_column_schema
126+
127+
128+
@pytest.mark.parametrize(
129+
["column_schema", "properties", "expected_column_schema"],
130+
[
131+
[
132+
ColumnSchema("example", properties={"a": "old"}),
133+
{"a": "new"},
134+
ColumnSchema("example", properties={"a": "new"}),
135+
],
136+
[
137+
ColumnSchema("example", properties={"a": 1, "b": 2}),
138+
{"a": 4, "c": 3},
139+
ColumnSchema("example", properties={"a": 4, "b": 2, "c": 3}),
140+
],
141+
[
142+
ColumnSchema(
143+
"example_col_2",
144+
dtype=numpy.float32,
145+
tags=[Tags.CONTINUOUS],
146+
properties={"a": 1, "domain": {"min": 0, "max": 5}},
147+
),
148+
{"a": 4, "c": 3, "domain": {"max": 8}},
149+
ColumnSchema(
150+
"example_col_2",
151+
dtype=numpy.float32,
152+
tags=[Tags.CONTINUOUS],
153+
properties={"a": 4, "c": 3, "domain": {"max": 8}},
154+
),
155+
],
156+
],
157+
)
158+
def test_with_properties(column_schema, properties, expected_column_schema):
159+
assert column_schema.with_properties(properties) == expected_column_schema
49160

50161

51162
def test_column_schema_tags_normalize():

0 commit comments

Comments
 (0)