@@ -111,7 +111,7 @@ def handle_parameter_cover(
111111 parameters [parameter_name ] = covered
112112 return True
113113
114- def _check_nullability () -> bool :
114+ def _check_nullability (check_nullability , parameterized_type , covered , kind ) -> bool :
115115 if not check_nullability :
116116 return True
117117 # The ANTLR context stores a Token called ``isnull`` – it is
@@ -165,25 +165,25 @@ def covers(
165165 return False
166166 if getattr (covered .varchar , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
167167 return False
168- return _check_nullability ()
168+ return _check_nullability (check_nullability , parameterized_type , covered , kind )
169169
170170 if isinstance (parameterized_type , SubstraitTypeParser .FixedCharContext ):
171171 if kind != "fixed_char" :
172172 return False
173173 if getattr (covered .fixed_char , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
174174 return False
175- return _check_nullability ()
175+ return _check_nullability (check_nullability , parameterized_type , covered , kind )
176176
177177 if isinstance (parameterized_type , SubstraitTypeParser .FixedBinaryContext ):
178178 if kind != "fixed_binary" :
179179 return False
180180 if getattr (covered .fixed_binary , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
181181 return False
182- return _check_nullability ()
182+ return _check_nullability (check_nullability , parameterized_type , covered , kind )
183183 if isinstance (parameterized_type , SubstraitTypeParser .DecimalContext ):
184184 if kind != "decimal" :
185185 return False
186- if not _check_nullability ():
186+ if not _check_nullability (check_nullability , parameterized_type , covered , kind ):
187187 return False
188188 # precision / scale are both optional – a missing value means “no limit”.
189189 covered_scale = getattr (covered .decimal , "scale" , 0 )
@@ -201,7 +201,7 @@ def covers(
201201 if isinstance (parameterized_type , SubstraitTypeParser .PrecisionTimestampContext ):
202202 if kind != "precision_timestamp" :
203203 return False
204- if not _check_nullability ():
204+ if not _check_nullability (check_nullability , parameterized_type , covered , kind ):
205205 return False
206206 covered_prec = getattr (covered .precision_timestamp , "precision" , 0 )
207207 param_prec = getattr (parameterized_type , "precision" , 0 )
@@ -211,7 +211,7 @@ def covers(
211211 if isinstance (parameterized_type , SubstraitTypeParser .PrecisionTimestampTZContext ):
212212 if kind != "precision_timestamp_tz" :
213213 return False
214- if not _check_nullability ():
214+ if not _check_nullability (check_nullability , parameterized_type , covered , kind ):
215215 return False
216216 covered_prec = getattr (covered .precision_timestamp_tz , "precision" , 0 )
217217 param_prec = getattr (parameterized_type , "precision" , 0 )
@@ -228,7 +228,7 @@ def covers(
228228 if isinstance (parameterized_type , ctx_cls ):
229229 if kind != expected_kind :
230230 return False
231- return _check_nullability ()
231+ return _check_nullability (check_nullability , parameterized_type , covered , kind )
232232 else :
233233 raise Exception (f"Unhandled type { type (parameterized_type )} " )
234234
0 commit comments