@@ -68,21 +68,24 @@ def normalize_substrait_type_names(typ: str) -> str:
6868 raise Exception (f"Unrecognized substrait type { typ } " )
6969
7070
71- def violates_integer_option (actual : int , option , parameters : dict ):
71+ def violates_integer_option (actual : int , option , parameters : dict , subset = False ):
72+ option_numeric = None
7273 if isinstance (option , SubstraitTypeParser .NumericLiteralContext ):
73- return actual ! = int (str (option .Number ()))
74+ option_numeric = int (str (option .Number ()))
7475 elif isinstance (option , SubstraitTypeParser .NumericParameterNameContext ):
7576 parameter_name = str (option .Identifier ())
76- if parameter_name in parameters and parameters [parameter_name ] != actual :
77- return True
78- else :
77+
78+ if parameter_name not in parameters :
7979 parameters [parameter_name ] = actual
80+ option_numeric = parameters [parameter_name ]
8081 else :
8182 raise Exception (
8283 f"Input should be either NumericLiteralContext or NumericParameterNameContext, got { type (option )} instead"
8384 )
84-
85- return False
85+ if subset :
86+ return actual < option_numeric
87+ else :
88+ return actual != option_numeric
8689
8790
8891def types_equal (type1 : Type , type2 : Type , check_nullability = False ):
@@ -111,7 +114,8 @@ def handle_parameter_cover(
111114 parameters [parameter_name ] = covered
112115 return True
113116
114- def _check_nullability (check_nullability ,parameterized_type ,covered ,kind ) -> bool :
117+
118+ def _check_nullability (check_nullability , parameterized_type , covered , kind ) -> bool :
115119 if not check_nullability :
116120 return True
117121 # The ANTLR context stores a Token called ``isnull`` – it is
@@ -121,6 +125,8 @@ def _check_nullability(check_nullability,parameterized_type,covered,kind) -> boo
121125 if getattr (parameterized_type , "isnull" , None ) is not None
122126 else Type .Nullability .NULLABILITY_REQUIRED
123127 )
128+ # if nullability == Type.Nullability.NULLABILITY_NULLABLE:
129+ # return True # is still true even if the covered is required
124130 # The protobuf message stores its own enum – we compare the two.
125131 covered_nullability = getattr (
126132 getattr (covered , kind ), # e.g. covered.varchar
@@ -129,6 +135,7 @@ def _check_nullability(check_nullability,parameterized_type,covered,kind) -> boo
129135 )
130136 return nullability == covered_nullability
131137
138+
132139def covers (
133140 covered : Type ,
134141 covering : SubstraitTypeParser .TypeLiteralContext ,
@@ -140,7 +147,6 @@ def covers(
140147 return handle_parameter_cover (
141148 covered , parameter_name , parameters , check_nullability
142149 )
143-
144150 covering : SubstraitTypeParser .TypeDefContext = covering .typeDef ()
145151
146152 any_type : SubstraitTypeParser .AnyTypeContext = covering .anyType ()
@@ -163,77 +169,95 @@ def covers(
163169 if isinstance (parameterized_type , SubstraitTypeParser .VarCharContext ):
164170 if kind != "varchar" :
165171 return False
166- if getattr (covered .varchar , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
172+ if hasattr (parameterized_type , "length" ) and violates_integer_option (
173+ covered .varchar .length , parameterized_type .length , parameters
174+ ):
167175 return False
168- return True
169- return _check_nullability (check_nullability ,parameterized_type ,covered ,kind )
170176
177+ return _check_nullability (
178+ check_nullability , parameterized_type , covered , kind
179+ )
171180 if isinstance (parameterized_type , SubstraitTypeParser .FixedCharContext ):
172181 if kind != "fixed_char" :
173182 return False
174- if getattr (covered .fixed_char , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
183+ if hasattr (parameterized_type , "length" ) and violates_integer_option (
184+ covered .fixed_char .length , parameterized_type .length , parameters
185+ ):
175186 return False
176- return True
177- return _check_nullability (check_nullability ,parameterized_type ,covered ,kind )
187+ return _check_nullability (
188+ check_nullability , parameterized_type , covered , kind
189+ )
178190
179191 if isinstance (parameterized_type , SubstraitTypeParser .FixedBinaryContext ):
180192 if kind != "fixed_binary" :
181193 return False
182- if getattr (covered .fixed_binary , "length" , 0 ) > getattr (parameterized_type , "length" , 0 ):
194+ if hasattr (parameterized_type , "length" ) and violates_integer_option (
195+ covered .fixed_binary .length , parameterized_type .length , parameters
196+ ):
183197 return False
184- return True
185- return _check_nullability (check_nullability ,parameterized_type ,covered ,kind )
198+ # return True
199+ return _check_nullability (
200+ check_nullability , parameterized_type , covered , kind
201+ )
186202 if isinstance (parameterized_type , SubstraitTypeParser .DecimalContext ):
187203 if kind != "decimal" :
188204 return False
189- if not _check_nullability (check_nullability ,parameterized_type ,covered ,kind ):
205+ if not _check_nullability (
206+ check_nullability , parameterized_type , covered , kind
207+ ):
190208 return False
191209 # precision / scale are both optional – a missing value means “no limit”.
192- covered_scale = getattr (covered .decimal , "scale" , 0 )
193- param_scale = getattr (parameterized_type , "scale" , 0 )
194- covered_prec = getattr (covered .decimal , "precision" , 0 )
195- param_prec = getattr (parameterized_type , "precision" , 0 )
210+ covered_scale = getattr (covered .decimal , "scale" , 0 )
211+ param_scale = getattr (parameterized_type , "scale" , 0 )
212+ covered_prec = getattr (covered .decimal , "precision" , 0 )
213+ param_prec = getattr (parameterized_type , "precision" , 0 )
196214 return not (
197- violates_integer_option (
198- covered_scale , param_scale , parameters
199- )
200- or violates_integer_option (
201- covered_prec , param_prec , parameters
202- )
215+ violates_integer_option (covered_scale , param_scale , parameters )
216+ or violates_integer_option (covered_prec , param_prec , parameters )
203217 )
204- if isinstance (parameterized_type , SubstraitTypeParser .PrecisionTimestampContext ):
218+ if isinstance (
219+ parameterized_type , SubstraitTypeParser .PrecisionTimestampContext
220+ ):
205221 if kind != "precision_timestamp" :
206222 return False
207- if not _check_nullability (check_nullability ,parameterized_type ,covered ,kind ):
223+ if not _check_nullability (
224+ check_nullability , parameterized_type , covered , kind
225+ ):
208226 return False
209- return True
210- # covered_prec = getattr(covered.precision_timestamp, "precision", 0)
211- # param_prec = getattr(parameterized_type, "precision", 0)
212- # return covered_prec == param_prec
213-
227+ # return True
228+ covered_prec = getattr (covered .precision_timestamp , "precision" , 0 )
229+ param_prec = getattr (parameterized_type , "precision" , 0 )
230+ return not violates_integer_option (covered_prec , param_prec , parameters )
214231
215- if isinstance (parameterized_type , SubstraitTypeParser .PrecisionTimestampTZContext ):
232+ if isinstance (
233+ parameterized_type , SubstraitTypeParser .PrecisionTimestampTZContext
234+ ):
216235 if kind != "precision_timestamp_tz" :
217236 return False
218- if not _check_nullability (check_nullability ,parameterized_type ,covered ,kind ):
237+ if not _check_nullability (
238+ check_nullability , parameterized_type , covered , kind
239+ ):
219240 return False
220- return True
241+ # return True
221242 covered_prec = getattr (covered .precision_timestamp_tz , "precision" , 0 )
222- param_prec = getattr (parameterized_type , "precision" , 0 )
223- return covered_prec == param_prec
243+ param_prec = getattr (parameterized_type , "precision" , 0 )
244+ return not violates_integer_option (covered_prec , param_prec , parameters )
245+
224246 kind_mapping = {
225- SubstraitTypeParser .ListContext : "list" ,
226- SubstraitTypeParser .MapContext : "map" ,
227- SubstraitTypeParser .StructContext : "struct" ,
228- SubstraitTypeParser .UserDefinedContext : "user_defined" ,
247+ SubstraitTypeParser .ListContext : "list" ,
248+ SubstraitTypeParser .MapContext : "map" ,
249+ SubstraitTypeParser .StructContext : "struct" ,
250+ SubstraitTypeParser .UserDefinedContext : "user_defined" ,
229251 SubstraitTypeParser .PrecisionIntervalDayContext : "interval_day" ,
230252 }
231253
232254 for ctx_cls , expected_kind in kind_mapping .items ():
233255 if isinstance (parameterized_type , ctx_cls ):
234256 if kind != expected_kind :
235257 return False
236- return _check_nullability (check_nullability ,parameterized_type ,covered ,kind )
258+ return _check_nullability (
259+ check_nullability , parameterized_type , covered , kind
260+ )
237261 else :
238262 raise Exception (f"Unhandled type { type (parameterized_type )} " )
239263
0 commit comments