Skip to content

Commit f1b1082

Browse files
giospadaGitHub Enterprise
authored andcommitted
Fix: Cover behavior of parametrized types (#4)
Fix: Cover behavior of parametrized types and increased the test coverage for extension_registry
1 parent 079e0aa commit f1b1082

File tree

2 files changed

+263
-139
lines changed

2 files changed

+263
-139
lines changed

src/substrait/extension_registry.py

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,24 @@ def normalize_substrait_type_names(typ: str) -> str:
6262
raise Exception(f"Unrecognized substrait type {typ}")
6363

6464

65-
def violates_integer_option(actual: int, option, parameters: dict):
65+
def violates_integer_option(actual: int, option, parameters: dict, subset=False):
66+
option_numeric = None
6667
if isinstance(option, SubstraitTypeParser.NumericLiteralContext):
67-
return actual != int(str(option.Number()))
68+
option_numeric = int(str(option.Number()))
6869
elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext):
6970
parameter_name = str(option.Identifier())
70-
if parameter_name in parameters and parameters[parameter_name] != actual:
71-
return True
72-
else:
71+
72+
if parameter_name not in parameters:
7373
parameters[parameter_name] = actual
74+
option_numeric = parameters[parameter_name]
7475
else:
7576
raise Exception(
7677
f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead"
7778
)
78-
79-
return False
79+
if subset:
80+
return actual < option_numeric
81+
else:
82+
return actual != option_numeric
8083

8184

8285
def types_equal(type1: Type, type2: Type, check_nullability=False):
@@ -105,7 +108,8 @@ def handle_parameter_cover(
105108
parameters[parameter_name] = covered
106109
return True
107110

108-
def _check_nullability(check_nullability,parameterized_type,covered,kind) -> bool:
111+
112+
def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool:
109113
if not check_nullability:
110114
return True
111115
# The ANTLR context stores a Token called ``isnull`` – it is
@@ -115,6 +119,8 @@ def _check_nullability(check_nullability,parameterized_type,covered,kind) -> boo
115119
if getattr(parameterized_type, "isnull", None) is not None
116120
else Type.Nullability.NULLABILITY_REQUIRED
117121
)
122+
# if nullability == Type.Nullability.NULLABILITY_NULLABLE:
123+
# return True # is still true even if the covered is required
118124
# The protobuf message stores its own enum – we compare the two.
119125
covered_nullability = getattr(
120126
getattr(covered, kind), # e.g. covered.varchar
@@ -123,6 +129,7 @@ def _check_nullability(check_nullability,parameterized_type,covered,kind) -> boo
123129
)
124130
return nullability == covered_nullability
125131

132+
126133
def covers(
127134
covered: Type,
128135
covering: SubstraitTypeParser.TypeLiteralContext,
@@ -134,7 +141,6 @@ def covers(
134141
return handle_parameter_cover(
135142
covered, parameter_name, parameters, check_nullability
136143
)
137-
138144
covering: SubstraitTypeParser.TypeDefContext = covering.typeDef()
139145

140146
any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType()
@@ -157,77 +163,95 @@ def covers(
157163
if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext):
158164
if kind != "varchar":
159165
return False
160-
if getattr(covered.varchar, "length", 0) > getattr(parameterized_type, "length", 0):
166+
if hasattr(parameterized_type, "length") and violates_integer_option(
167+
covered.varchar.length, parameterized_type.length, parameters
168+
):
161169
return False
162-
return True
163-
return _check_nullability(check_nullability,parameterized_type,covered,kind)
164170

171+
return _check_nullability(
172+
check_nullability, parameterized_type, covered, kind
173+
)
165174
if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext):
166175
if kind != "fixed_char":
167176
return False
168-
if getattr(covered.fixed_char, "length", 0) > getattr(parameterized_type, "length", 0):
177+
if hasattr(parameterized_type, "length") and violates_integer_option(
178+
covered.fixed_char.length, parameterized_type.length, parameters
179+
):
169180
return False
170-
return True
171-
return _check_nullability(check_nullability,parameterized_type,covered,kind)
181+
return _check_nullability(
182+
check_nullability, parameterized_type, covered, kind
183+
)
172184

173185
if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext):
174186
if kind != "fixed_binary":
175187
return False
176-
if getattr(covered.fixed_binary, "length", 0) > getattr(parameterized_type, "length", 0):
188+
if hasattr(parameterized_type, "length") and violates_integer_option(
189+
covered.fixed_binary.length, parameterized_type.length, parameters
190+
):
177191
return False
178-
return True
179-
return _check_nullability(check_nullability,parameterized_type,covered,kind)
192+
# return True
193+
return _check_nullability(
194+
check_nullability, parameterized_type, covered, kind
195+
)
180196
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
181197
if kind != "decimal":
182198
return False
183-
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
199+
if not _check_nullability(
200+
check_nullability, parameterized_type, covered, kind
201+
):
184202
return False
185203
# precision / scale are both optional – a missing value means “no limit”.
186-
covered_scale = getattr(covered.decimal, "scale", 0)
187-
param_scale = getattr(parameterized_type, "scale", 0)
188-
covered_prec = getattr(covered.decimal, "precision", 0)
189-
param_prec = getattr(parameterized_type, "precision", 0)
204+
covered_scale = getattr(covered.decimal, "scale", 0)
205+
param_scale = getattr(parameterized_type, "scale", 0)
206+
covered_prec = getattr(covered.decimal, "precision", 0)
207+
param_prec = getattr(parameterized_type, "precision", 0)
190208
return not (
191-
violates_integer_option(
192-
covered_scale, param_scale, parameters
193-
)
194-
or violates_integer_option(
195-
covered_prec, param_prec, parameters
196-
)
209+
violates_integer_option(covered_scale, param_scale, parameters)
210+
or violates_integer_option(covered_prec, param_prec, parameters)
197211
)
198-
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext):
212+
if isinstance(
213+
parameterized_type, SubstraitTypeParser.PrecisionTimestampContext
214+
):
199215
if kind != "precision_timestamp":
200216
return False
201-
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
217+
if not _check_nullability(
218+
check_nullability, parameterized_type, covered, kind
219+
):
202220
return False
203-
return True
204-
# covered_prec = getattr(covered.precision_timestamp, "precision", 0)
205-
# param_prec = getattr(parameterized_type, "precision", 0)
206-
# return covered_prec == param_prec
207-
221+
# return True
222+
covered_prec = getattr(covered.precision_timestamp, "precision", 0)
223+
param_prec = getattr(parameterized_type, "precision", 0)
224+
return not violates_integer_option(covered_prec, param_prec, parameters)
208225

209-
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext):
226+
if isinstance(
227+
parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext
228+
):
210229
if kind != "precision_timestamp_tz":
211230
return False
212-
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
231+
if not _check_nullability(
232+
check_nullability, parameterized_type, covered, kind
233+
):
213234
return False
214-
return True
235+
# return True
215236
covered_prec = getattr(covered.precision_timestamp_tz, "precision", 0)
216-
param_prec = getattr(parameterized_type, "precision", 0)
217-
return covered_prec == param_prec
237+
param_prec = getattr(parameterized_type, "precision", 0)
238+
return not violates_integer_option(covered_prec, param_prec, parameters)
239+
218240
kind_mapping = {
219-
SubstraitTypeParser.ListContext: "list",
220-
SubstraitTypeParser.MapContext: "map",
221-
SubstraitTypeParser.StructContext: "struct",
222-
SubstraitTypeParser.UserDefinedContext: "user_defined",
241+
SubstraitTypeParser.ListContext: "list",
242+
SubstraitTypeParser.MapContext: "map",
243+
SubstraitTypeParser.StructContext: "struct",
244+
SubstraitTypeParser.UserDefinedContext: "user_defined",
223245
SubstraitTypeParser.PrecisionIntervalDayContext: "interval_day",
224246
}
225247

226248
for ctx_cls, expected_kind in kind_mapping.items():
227249
if isinstance(parameterized_type, ctx_cls):
228250
if kind != expected_kind:
229251
return False
230-
return _check_nullability(check_nullability,parameterized_type,covered,kind)
252+
return _check_nullability(
253+
check_nullability, parameterized_type, covered, kind
254+
)
231255
else:
232256
raise Exception(f"Unhandled type {type(parameterized_type)}")
233257

0 commit comments

Comments
 (0)