Skip to content

Commit af94be0

Browse files
committed
fix: check nullability
1 parent c689be5 commit af94be0

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/substrait/extension_registry.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def handle_parameter_cover(
111111
parameters[parameter_name] = covered
112112
return True
113113

114-
def _check_nullability() -> bool:
114+
def _check_nullability(check_nullability,parameterized_type,covered,kind) -> bool:
115115
if not check_nullability:
116116
return True
117117
# The ANTLR context stores a Token called ``isnull`` – it is
@@ -165,25 +165,25 @@ def covers(
165165
return False
166166
if getattr(covered.varchar, "length", 0) > getattr(parameterized_type, "length", 0):
167167
return False
168-
return _check_nullability()
168+
return _check_nullability(check_nullability,parameterized_type,covered,kind)
169169

170170
if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext):
171171
if kind != "fixed_char":
172172
return False
173173
if getattr(covered.fixed_char, "length", 0) > getattr(parameterized_type, "length", 0):
174174
return False
175-
return _check_nullability()
175+
return _check_nullability(check_nullability,parameterized_type,covered,kind)
176176

177177
if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext):
178178
if kind != "fixed_binary":
179179
return False
180180
if getattr(covered.fixed_binary, "length", 0) > getattr(parameterized_type, "length", 0):
181181
return False
182-
return _check_nullability()
182+
return _check_nullability(check_nullability,parameterized_type,covered,kind)
183183
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
184184
if kind != "decimal":
185185
return False
186-
if not _check_nullability():
186+
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
187187
return False
188188
# precision / scale are both optional – a missing value means “no limit”.
189189
covered_scale = getattr(covered.decimal, "scale", 0)
@@ -201,7 +201,7 @@ def covers(
201201
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext):
202202
if kind != "precision_timestamp":
203203
return False
204-
if not _check_nullability():
204+
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
205205
return False
206206
covered_prec = getattr(covered.precision_timestamp, "precision", 0)
207207
param_prec = getattr(parameterized_type, "precision", 0)
@@ -211,7 +211,7 @@ def covers(
211211
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext):
212212
if kind != "precision_timestamp_tz":
213213
return False
214-
if not _check_nullability():
214+
if not _check_nullability(check_nullability,parameterized_type,covered,kind):
215215
return False
216216
covered_prec = getattr(covered.precision_timestamp_tz, "precision", 0)
217217
param_prec = getattr(parameterized_type, "precision", 0)
@@ -228,7 +228,7 @@ def covers(
228228
if isinstance(parameterized_type, ctx_cls):
229229
if kind != expected_kind:
230230
return False
231-
return _check_nullability()
231+
return _check_nullability(check_nullability,parameterized_type,covered,kind)
232232
else:
233233
raise Exception(f"Unhandled type {type(parameterized_type)}")
234234

0 commit comments

Comments
 (0)