@@ -62,21 +62,24 @@ def normalize_substrait_type_names(typ: str) -> str:
6262 raise Exception (f"Unrecognized substrait type { typ } " )
6363
6464
65- def violates_integer_option (actual : int , option , parameters : dict ):
65+ def violates_integer_option (actual : int , option , parameters : dict , subset = False ):
66+ option_numeric = None
6667 if isinstance (option , SubstraitTypeParser .NumericLiteralContext ):
67- return actual ! = int (str (option .Number ()))
68+ option_numeric = int (str (option .Number ()))
6869 elif isinstance (option , SubstraitTypeParser .NumericParameterNameContext ):
6970 parameter_name = str (option .Identifier ())
70- if parameter_name in parameters and parameters [parameter_name ] != actual :
71- return True
72- else :
71+
72+ if parameter_name not in parameters :
7373 parameters [parameter_name ] = actual
74+ option_numeric = parameters [parameter_name ]
7475 else :
7576 raise Exception (
7677 f"Input should be either NumericLiteralContext or NumericParameterNameContext, got { type (option )} instead"
7778 )
78-
79- return False
79+ if subset :
80+ return actual < option_numeric
81+ else :
82+ return actual != option_numeric
8083
8184
8285def types_equal (type1 : Type , type2 : Type , check_nullability = False ):
@@ -105,7 +108,8 @@ def handle_parameter_cover(
105108 parameters [parameter_name ] = covered
106109 return True
107110
108- def _check_nullability (check_nullability ,parameterized_type ,covered ,kind ) -> bool :
111+
112+ def _check_nullability (check_nullability , parameterized_type , covered , kind ) -> bool :
109113 if not check_nullability :
110114 return True
111115 # The ANTLR context stores a Token called ``isnull`` – it is
@@ -115,6 +119,8 @@ def _check_nullability(check_nullability,parameterized_type,covered,kind) -> boo
115119 if getattr (parameterized_type , "isnull" , None ) is not None
116120 else Type .Nullability .NULLABILITY_REQUIRED
117121 )
122+ # if nullability == Type.Nullability.NULLABILITY_NULLABLE:
123+ # return True # is still true even if the covered is required
118124 # The protobuf message stores its own enum – we compare the two.
119125 covered_nullability = getattr (
120126 getattr (covered , kind ), # e.g. covered.varchar
@@ -123,6 +129,7 @@ def _check_nullability(check_nullability,parameterized_type,covered,kind) -> boo
123129 )
124130 return nullability == covered_nullability
125131
132+
126133def covers (
127134 covered : Type ,
128135 covering : SubstraitTypeParser .TypeLiteralContext ,
@@ -134,7 +141,6 @@ def covers(
134141 return handle_parameter_cover (
135142 covered , parameter_name , parameters , check_nullability
136143 )
137-
138144 covering : SubstraitTypeParser .TypeDefContext = covering .typeDef ()
139145
140146 any_type : SubstraitTypeParser .AnyTypeContext = covering .anyType ()
@@ -157,77 +163,95 @@ def covers(
157163 if isinstance (parameterized_type , SubstraitTypeParser .VarCharContext ):
158164 if kind != "varchar" :
159165 return False
160- if getattr (covered .varchar , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
166+ if hasattr (parameterized_type , "length" ) and violates_integer_option (
167+ covered .varchar .length , parameterized_type .length , parameters
168+ ):
161169 return False
162- return True
163- return _check_nullability (check_nullability ,parameterized_type ,covered ,kind )
164170
171+ return _check_nullability (
172+ check_nullability , parameterized_type , covered , kind
173+ )
165174 if isinstance (parameterized_type , SubstraitTypeParser .FixedCharContext ):
166175 if kind != "fixed_char" :
167176 return False
168- if getattr (covered .fixed_char , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
177+ if hasattr (parameterized_type , "length" ) and violates_integer_option (
178+ covered .fixed_char .length , parameterized_type .length , parameters
179+ ):
169180 return False
170- return True
171- return _check_nullability (check_nullability ,parameterized_type ,covered ,kind )
181+ return _check_nullability (
182+ check_nullability , parameterized_type , covered , kind
183+ )
172184
173185 if isinstance (parameterized_type , SubstraitTypeParser .FixedBinaryContext ):
174186 if kind != "fixed_binary" :
175187 return False
176- if getattr (covered .fixed_binary , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
188+ if hasattr (parameterized_type , "length" ) and violates_integer_option (
189+ covered .fixed_binary .length , parameterized_type .length , parameters
190+ ):
177191 return False
178- return True
179- return _check_nullability (check_nullability ,parameterized_type ,covered ,kind )
192+ # return True
193+ return _check_nullability (
194+ check_nullability , parameterized_type , covered , kind
195+ )
180196 if isinstance (parameterized_type , SubstraitTypeParser .DecimalContext ):
181197 if kind != "decimal" :
182198 return False
183- if not _check_nullability (check_nullability ,parameterized_type ,covered ,kind ):
199+ if not _check_nullability (
200+ check_nullability , parameterized_type , covered , kind
201+ ):
184202 return False
185203 # precision / scale are both optional – a missing value means “no limit”.
186- covered_scale = getattr (covered .decimal , "scale" , 0 )
187- param_scale = getattr (parameterized_type , "scale" , 0 )
188- covered_prec = getattr (covered .decimal , "precision" , 0 )
189- param_prec = getattr (parameterized_type , "precision" , 0 )
204+ covered_scale = getattr (covered .decimal , "scale" , 0 )
205+ param_scale = getattr (parameterized_type , "scale" , 0 )
206+ covered_prec = getattr (covered .decimal , "precision" , 0 )
207+ param_prec = getattr (parameterized_type , "precision" , 0 )
190208 return not (
191- violates_integer_option (
192- covered_scale , param_scale , parameters
193- )
194- or violates_integer_option (
195- covered_prec , param_prec , parameters
196- )
209+ violates_integer_option (covered_scale , param_scale , parameters )
210+ or violates_integer_option (covered_prec , param_prec , parameters )
197211 )
198- if isinstance (parameterized_type , SubstraitTypeParser .PrecisionTimestampContext ):
212+ if isinstance (
213+ parameterized_type , SubstraitTypeParser .PrecisionTimestampContext
214+ ):
199215 if kind != "precision_timestamp" :
200216 return False
201- if not _check_nullability (check_nullability ,parameterized_type ,covered ,kind ):
217+ if not _check_nullability (
218+ check_nullability , parameterized_type , covered , kind
219+ ):
202220 return False
203- return True
204- # covered_prec = getattr(covered.precision_timestamp, "precision", 0)
205- # param_prec = getattr(parameterized_type, "precision", 0)
206- # return covered_prec == param_prec
207-
221+ # return True
222+ covered_prec = getattr (covered .precision_timestamp , "precision" , 0 )
223+ param_prec = getattr (parameterized_type , "precision" , 0 )
224+ return not violates_integer_option (covered_prec , param_prec , parameters )
208225
209- if isinstance (parameterized_type , SubstraitTypeParser .PrecisionTimestampTZContext ):
226+ if isinstance (
227+ parameterized_type , SubstraitTypeParser .PrecisionTimestampTZContext
228+ ):
210229 if kind != "precision_timestamp_tz" :
211230 return False
212- if not _check_nullability (check_nullability ,parameterized_type ,covered ,kind ):
231+ if not _check_nullability (
232+ check_nullability , parameterized_type , covered , kind
233+ ):
213234 return False
214- return True
235+ # return True
215236 covered_prec = getattr (covered .precision_timestamp_tz , "precision" , 0 )
216- param_prec = getattr (parameterized_type , "precision" , 0 )
217- return covered_prec == param_prec
237+ param_prec = getattr (parameterized_type , "precision" , 0 )
238+ return not violates_integer_option (covered_prec , param_prec , parameters )
239+
218240 kind_mapping = {
219- SubstraitTypeParser .ListContext : "list" ,
220- SubstraitTypeParser .MapContext : "map" ,
221- SubstraitTypeParser .StructContext : "struct" ,
222- SubstraitTypeParser .UserDefinedContext : "user_defined" ,
241+ SubstraitTypeParser .ListContext : "list" ,
242+ SubstraitTypeParser .MapContext : "map" ,
243+ SubstraitTypeParser .StructContext : "struct" ,
244+ SubstraitTypeParser .UserDefinedContext : "user_defined" ,
223245 SubstraitTypeParser .PrecisionIntervalDayContext : "interval_day" ,
224246 }
225247
226248 for ctx_cls , expected_kind in kind_mapping .items ():
227249 if isinstance (parameterized_type , ctx_cls ):
228250 if kind != expected_kind :
229251 return False
230- return _check_nullability (check_nullability ,parameterized_type ,covered ,kind )
252+ return _check_nullability (
253+ check_nullability , parameterized_type , covered , kind
254+ )
231255 else :
232256 raise Exception (f"Unhandled type { type (parameterized_type )} " )
233257
0 commit comments