Skip to content

Commit bd1f5b9

Browse files
committed
ref: added some test and refactor of the extension_registry test in multiple file
1 parent 6a81e20 commit bd1f5b9

File tree

8 files changed

+1339
-892
lines changed

8 files changed

+1339
-892
lines changed

tests/extension_registry/__init__.py

Whitespace-only changes.
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
"""Tests for parsing of a registry yaml and basic registry operations (lookup, registration)."""
2+
3+
from substrait.builders.type import i8, i16
4+
from substrait.builders.type import (
5+
decimal,
6+
i8,
7+
i16,
8+
i32,
9+
struct,
10+
)
11+
from substrait.gen.proto.type_pb2 import Type
12+
13+
import pytest
14+
import yaml
15+
16+
from substrait.extension_registry import ExtensionRegistry
17+
18+
# Common test YAML content for testing basic functions
19+
CONTENT = """%YAML 1.2
20+
---
21+
urn: extension:test:functions
22+
scalar_functions:
23+
- name: "test_fn"
24+
description: ""
25+
impls:
26+
- args:
27+
- value: i8
28+
variadic:
29+
min: 2
30+
return: i8
31+
- name: "test_fn_variadic_any"
32+
description: ""
33+
impls:
34+
- args:
35+
- value: any1
36+
variadic:
37+
min: 2
38+
return: any1
39+
- name: "add"
40+
description: "Add two values."
41+
impls:
42+
- args:
43+
- name: x
44+
value: i8
45+
- name: y
46+
value: i8
47+
options:
48+
overflow:
49+
values: [ SILENT, SATURATE, ERROR ]
50+
return: i8
51+
- args:
52+
- name: x
53+
value: i8
54+
- name: y
55+
value: i8
56+
- name: z
57+
value: any
58+
options:
59+
overflow:
60+
values: [ SILENT, SATURATE, ERROR ]
61+
return: i16
62+
- args:
63+
- name: x
64+
value: any1
65+
- name: y
66+
value: any1
67+
- name: z
68+
value: any2
69+
options:
70+
overflow:
71+
values: [ SILENT, SATURATE, ERROR ]
72+
return: any2
73+
- name: "test_decimal"
74+
impls:
75+
- args:
76+
- name: x
77+
value: decimal<P1,S1>
78+
- name: y
79+
value: decimal<S1,S2>
80+
return: decimal<P1 + 1,S2 + 1>
81+
- name: "test_enum"
82+
impls:
83+
- args:
84+
- name: op
85+
options: [ INTACT, FLIP ]
86+
- name: x
87+
value: i8
88+
return: i8
89+
- name: "add_declared"
90+
description: "Add two values."
91+
impls:
92+
- args:
93+
- name: x
94+
value: i8
95+
- name: y
96+
value: i8
97+
nullability: DECLARED_OUTPUT
98+
return: i8?
99+
- name: "add_discrete"
100+
description: "Add two values."
101+
impls:
102+
- args:
103+
- name: x
104+
value: i8?
105+
- name: y
106+
value: i8
107+
nullability: DISCRETE
108+
return: i8?
109+
- name: "test_decimal_discrete"
110+
impls:
111+
- args:
112+
- name: x
113+
value: decimal?<P1,S1>
114+
- name: y
115+
value: decimal<S1,S2>
116+
nullability: DISCRETE
117+
return: decimal?<P1 + 1,S2 + 1>
118+
- name: "equal_test"
119+
impls:
120+
- args:
121+
- name: x
122+
value: any
123+
- name: y
124+
value: any
125+
nullability: DISCRETE
126+
return: any
127+
"""
128+
129+
130+
@pytest.fixture(scope="session")
131+
def registry():
132+
"""Create a registry with test functions loaded."""
133+
reg = ExtensionRegistry(load_default_extensions=True)
134+
reg.register_extension_dict(
135+
yaml.safe_load(CONTENT),
136+
uri="https://test.example.com/extension_test_functions.yaml",
137+
)
138+
return reg
139+
140+
141+
142+
143+
def test_non_existing_urn(registry):
144+
assert (
145+
registry.lookup_function(
146+
urn="non_existent",
147+
function_name="add",
148+
signature=[i8(nullable=False), i8(nullable=False)],
149+
)
150+
is None
151+
)
152+
153+
154+
def test_non_existing_function(registry):
155+
assert (
156+
registry.lookup_function(
157+
urn="extension:test:functions",
158+
function_name="sub",
159+
signature=[i8(nullable=False), i8(nullable=False)],
160+
)
161+
is None
162+
)
163+
164+
165+
def test_non_existing_function_signature(registry):
166+
assert (
167+
registry.lookup_function(
168+
urn="extension:test:functions",
169+
function_name="add",
170+
signature=[i8(nullable=False)],
171+
)
172+
is None
173+
)
174+
175+
176+
def test_exact_match(registry):
177+
assert registry.lookup_function(
178+
urn="extension:test:functions",
179+
function_name="add",
180+
signature=[i8(nullable=False), i8(nullable=False)],
181+
)[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED))
182+
183+
184+
def test_wildcard_match(registry):
185+
assert registry.lookup_function(
186+
urn="extension:test:functions",
187+
function_name="add",
188+
signature=[i8(nullable=False), i8(nullable=False), bool()],
189+
)[1] == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED))
190+
191+
192+
def test_wildcard_match_fails_with_constraits(registry):
193+
assert (
194+
registry.lookup_function(
195+
urn="extension:test:functions",
196+
function_name="add",
197+
signature=[i8(nullable=False), i16(nullable=False), i16(nullable=False)],
198+
)
199+
is None
200+
)
201+
202+
203+
def test_wildcard_match_with_constraits(registry):
204+
assert registry.lookup_function(
205+
urn="extension:test:functions",
206+
function_name="add",
207+
signature=[i16(nullable=False), i16(nullable=False), i8(nullable=False)],
208+
)[1] == i8(nullable=False)
209+
210+
211+
def test_variadic(registry):
212+
assert registry.lookup_function(
213+
urn="extension:test:functions",
214+
function_name="test_fn",
215+
signature=[i8(nullable=False), i8(nullable=False), i8(nullable=False)],
216+
)[1] == i8(nullable=False)
217+
218+
219+
def test_variadic_any(registry):
220+
assert registry.lookup_function(
221+
urn="extension:test:functions",
222+
function_name="test_fn_variadic_any",
223+
signature=[i16(nullable=False), i16(nullable=False), i16(nullable=False)],
224+
)[1] == i16(nullable=False)
225+
226+
227+
def test_variadic_fails_min_constraint(registry):
228+
assert (
229+
registry.lookup_function(
230+
urn="extension:test:functions",
231+
function_name="test_fn",
232+
signature=[i8(nullable=False)],
233+
)
234+
is None
235+
)
236+
237+
238+
def test_decimal_happy_path(registry):
239+
assert registry.lookup_function(
240+
urn="extension:test:functions",
241+
function_name="test_decimal",
242+
signature=[decimal(8, 10, nullable=False), decimal(6, 8, nullable=False)],
243+
)[1] == decimal(7, 11, nullable=False)
244+
245+
246+
def test_decimal_violates_constraint(registry):
247+
assert (
248+
registry.lookup_function(
249+
urn="extension:test:functions",
250+
function_name="test_decimal",
251+
signature=[decimal(8, 10, nullable=False), decimal(10, 12, nullable=False)],
252+
)
253+
is None
254+
)
255+
256+
257+
def test_decimal_happy_path_discrete(registry):
258+
assert registry.lookup_function(
259+
urn="extension:test:functions",
260+
function_name="test_decimal_discrete",
261+
signature=[decimal(8, 10, nullable=True), decimal(6, 8, nullable=False)],
262+
)[1] == decimal(7, 11, nullable=True)
263+
264+
265+
def test_enum_with_valid_option(registry):
266+
assert registry.lookup_function(
267+
urn="extension:test:functions",
268+
function_name="test_enum",
269+
signature=["FLIP", i8(nullable=False)],
270+
)[1] == i8(nullable=False)
271+
272+
273+
def test_enum_with_nonexistent_option(registry):
274+
assert (
275+
registry.lookup_function(
276+
urn="extension:test:functions",
277+
function_name="test_enum",
278+
signature=["NONEXISTENT", i8(nullable=False)],
279+
)
280+
is None
281+
)
282+
283+
284+
def test_function_with_nullable_args(registry):
285+
assert registry.lookup_function(
286+
urn="extension:test:functions",
287+
function_name="add",
288+
signature=[i8(nullable=True), i8(nullable=False)],
289+
)[1] == i8(nullable=True)
290+
291+
292+
def test_function_with_declared_output_nullability(registry):
293+
assert registry.lookup_function(
294+
urn="extension:test:functions",
295+
function_name="add_declared",
296+
signature=[i8(nullable=False), i8(nullable=False)],
297+
)[1] == i8(nullable=True)
298+
299+
300+
def test_function_with_discrete_nullability(registry):
301+
assert registry.lookup_function(
302+
urn="extension:test:functions",
303+
function_name="add_discrete",
304+
signature=[i8(nullable=True), i8(nullable=False)],
305+
)[1] == i8(nullable=True)
306+
307+
308+
def test_function_with_discrete_nullability_nonexisting(registry):
309+
assert (
310+
registry.lookup_function(
311+
urn="extension:test:functions",
312+
function_name="add_discrete",
313+
signature=[i8(nullable=False), i8(nullable=False)],
314+
)
315+
is None
316+
)

0 commit comments

Comments
 (0)