Skip to content

Commit 81bc167

Browse files
committed
Fix: Cover behavior of parametrized types (#4)
Fix: Cover behavior of parametrized types and increased the test coverage for extension_registry
1 parent 76b236d commit 81bc167

File tree

2 files changed

+226
-118
lines changed

2 files changed

+226
-118
lines changed

src/substrait/extension_registry.py

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,24 @@ def normalize_substrait_type_names(typ: str) -> str:
6868
raise Exception(f"Unrecognized substrait type {typ}")
6969

7070

71-
def violates_integer_option(actual: int, option, parameters: dict):
71+
def violates_integer_option(actual: int, option, parameters: dict, subset=False):
72+
option_numeric = None
7273
if isinstance(option, SubstraitTypeParser.NumericLiteralContext):
73-
return actual != int(str(option.Number()))
74+
option_numeric = int(str(option.Number()))
7475
elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext):
7576
parameter_name = str(option.Identifier())
76-
if parameter_name in parameters and parameters[parameter_name] != actual:
77-
return True
78-
else:
77+
78+
if parameter_name not in parameters:
7979
parameters[parameter_name] = actual
80+
option_numeric = parameters[parameter_name]
8081
else:
8182
raise Exception(
8283
f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead"
8384
)
84-
85-
return False
85+
if subset:
86+
return actual < option_numeric
87+
else:
88+
return actual != option_numeric
8689

8790

8891
def types_equal(type1: Type, type2: Type, check_nullability=False):
@@ -111,7 +114,8 @@ def handle_parameter_cover(
111114
parameters[parameter_name] = covered
112115
return True
113116

114-
def _check_nullability(check_nullability,parameterized_type,covered,kind) -> bool:
117+
118+
def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool:
115119
if not check_nullability:
116120
return True
117121
# The ANTLR context stores a Token called ``isnull`` – it is
@@ -121,6 +125,8 @@ def _check_nullability(check_nullability,parameterized_type,covered,kind) -> boo
121125
if getattr(parameterized_type, "isnull", None) is not None
122126
else Type.Nullability.NULLABILITY_REQUIRED
123127
)
128+
# if nullability == Type.Nullability.NULLABILITY_NULLABLE:
129+
# return True # is still true even if the covered is required
124130
# The protobuf message stores its own enum – we compare the two.
125131
covered_nullability = getattr(
126132
getattr(covered, kind), # e.g. covered.varchar
@@ -129,6 +135,7 @@ def _check_nullability(check_nullability,parameterized_type,covered,kind) -> boo
129135
)
130136
return nullability == covered_nullability
131137

138+
132139
def covers(
133140
covered: Type,
134141
covering: SubstraitTypeParser.TypeLiteralContext,
@@ -140,7 +147,6 @@ def covers(
140147
return handle_parameter_cover(
141148
covered, parameter_name, parameters, check_nullability
142149
)
143-
144150
covering: SubstraitTypeParser.TypeDefContext = covering.typeDef()
145151

146152
any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType()
@@ -163,77 +169,95 @@ def covers(
163169
if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext):
164170
if kind != "varchar":
165171
return False
166-
if getattr(covered.varchar, "length", 0) > getattr(parameterized_type, "length", 0):
172+
if hasattr(parameterized_type, "length") and violates_integer_option(
173+
covered.varchar.length, parameterized_type.length, parameters
174+
):
167175
return False
168-
return True
169-
return _check_nullability(check_nullability,parameterized_type,covered,kind)
170176

177+
return _check_nullability(
178+
check_nullability, parameterized_type, covered, kind
179+
)
171180
if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext):
172181
if kind != "fixed_char":
173182
return False
174-
if getattr(covered.fixed_char, "length", 0) > getattr(parameterized_type, "length", 0):
183+
if hasattr(parameterized_type, "length") and violates_integer_option(
184+
covered.fixed_char.length, parameterized_type.length, parameters
185+
):
175186
return False
176-
return True
177-
return _check_nullability(check_nullability,parameterized_type,covered,kind)
187+
return _check_nullability(
188+
check_nullability, parameterized_type, covered, kind
189+
)
178190

179191
if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext):
180192
if kind != "fixed_binary":
181193
return False
182-
if getattr(covered.fixed_binary, "length", 0) > getattr(parameterized_type, "length", 0):
194+
if hasattr(parameterized_type, "length") and violates_integer_option(
195+
covered.fixed_binary.length, parameterized_type.length, parameters
196+
):
183197
return False
184-
return True
185-
return _check_nullability(check_nullability,parameterized_type,covered,kind)
198+
# return True
199+
return _check_nullability(
200+
check_nullability, parameterized_type, covered, kind
201+
)
186202
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
187203
if kind != "decimal":
188204
return False
189-
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
205+
if not _check_nullability(
206+
check_nullability, parameterized_type, covered, kind
207+
):
190208
return False
191209
# precision / scale are both optional – a missing value means “no limit”.
192-
covered_scale = getattr(covered.decimal, "scale", 0)
193-
param_scale = getattr(parameterized_type, "scale", 0)
194-
covered_prec = getattr(covered.decimal, "precision", 0)
195-
param_prec = getattr(parameterized_type, "precision", 0)
210+
covered_scale = getattr(covered.decimal, "scale", 0)
211+
param_scale = getattr(parameterized_type, "scale", 0)
212+
covered_prec = getattr(covered.decimal, "precision", 0)
213+
param_prec = getattr(parameterized_type, "precision", 0)
196214
return not (
197-
violates_integer_option(
198-
covered_scale, param_scale, parameters
199-
)
200-
or violates_integer_option(
201-
covered_prec, param_prec, parameters
202-
)
215+
violates_integer_option(covered_scale, param_scale, parameters)
216+
or violates_integer_option(covered_prec, param_prec, parameters)
203217
)
204-
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext):
218+
if isinstance(
219+
parameterized_type, SubstraitTypeParser.PrecisionTimestampContext
220+
):
205221
if kind != "precision_timestamp":
206222
return False
207-
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
223+
if not _check_nullability(
224+
check_nullability, parameterized_type, covered, kind
225+
):
208226
return False
209-
return True
210-
# covered_prec = getattr(covered.precision_timestamp, "precision", 0)
211-
# param_prec = getattr(parameterized_type, "precision", 0)
212-
# return covered_prec == param_prec
213-
227+
# return True
228+
covered_prec = getattr(covered.precision_timestamp, "precision", 0)
229+
param_prec = getattr(parameterized_type, "precision", 0)
230+
return not violates_integer_option(covered_prec, param_prec, parameters)
214231

215-
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext):
232+
if isinstance(
233+
parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext
234+
):
216235
if kind != "precision_timestamp_tz":
217236
return False
218-
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
237+
if not _check_nullability(
238+
check_nullability, parameterized_type, covered, kind
239+
):
219240
return False
220-
return True
241+
# return True
221242
covered_prec = getattr(covered.precision_timestamp_tz, "precision", 0)
222-
param_prec = getattr(parameterized_type, "precision", 0)
223-
return covered_prec == param_prec
243+
param_prec = getattr(parameterized_type, "precision", 0)
244+
return not violates_integer_option(covered_prec, param_prec, parameters)
245+
224246
kind_mapping = {
225-
SubstraitTypeParser.ListContext: "list",
226-
SubstraitTypeParser.MapContext: "map",
227-
SubstraitTypeParser.StructContext: "struct",
228-
SubstraitTypeParser.UserDefinedContext: "user_defined",
247+
SubstraitTypeParser.ListContext: "list",
248+
SubstraitTypeParser.MapContext: "map",
249+
SubstraitTypeParser.StructContext: "struct",
250+
SubstraitTypeParser.UserDefinedContext: "user_defined",
229251
SubstraitTypeParser.PrecisionIntervalDayContext: "interval_day",
230252
}
231253

232254
for ctx_cls, expected_kind in kind_mapping.items():
233255
if isinstance(parameterized_type, ctx_cls):
234256
if kind != expected_kind:
235257
return False
236-
return _check_nullability(check_nullability,parameterized_type,covered,kind)
258+
return _check_nullability(
259+
check_nullability, parameterized_type, covered, kind
260+
)
237261
else:
238262
raise Exception(f"Unhandled type {type(parameterized_type)}")
239263

0 commit comments

Comments
 (0)