@@ -285,18 +285,17 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
285285
286286 NOTE: When symbool and symfloat are supported bool and float lists will be stored boxed.
287287 """
288- elem_type = type (val_type )
289288
290- if elem_type == torch .BoolType :
289+ if isinstance ( val_type , torch .BoolType ) :
291290 return EValue (BoolList (typing .cast (List [bool ], val )))
292291
293- if elem_type == torch .IntType :
292+ if isinstance ( val_type , torch .IntType ) :
294293 return self ._emit_int_list (val )
295294
296- if elem_type == torch .FloatType :
295+ if isinstance ( val_type , torch .FloatType ) :
297296 return EValue (DoubleList (typing .cast (List [float ], val )))
298297
299- if elem_type == torch .TensorType :
298+ if isinstance ( val_type , torch .TensorType ) :
300299 values = []
301300 for v in val :
302301 assert isinstance (v , _AbstractValue )
@@ -308,10 +307,10 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
308307 values .append (v .id )
309308 return EValue (TensorList (values ))
310309
311- if elem_type == torch .OptionalType :
310+ if isinstance ( val_type , torch .OptionalType ) :
312311 # refine further
313- actual_type = typing . cast ( torch . OptionalType , val_type ) .getElementType ()
314- if type (actual_type ) == torch .TensorType :
312+ actual_type = val_type .getElementType ()
313+ if isinstance (actual_type , torch .TensorType ) :
315314 vals = []
316315 for v in val :
317316 if v is None :
@@ -437,9 +436,9 @@ def _constant_to_evalue( # noqa: C901
437436 val_type = torch .ListType (
438437 self ._get_list_tuple_jit_type (val ) # pyre-ignore
439438 )
440- if type (val_type ) == torch .OptionalType :
439+ if isinstance (val_type , torch .OptionalType ) :
441440 val_type = val_type .getElementType ()
442- assert type (val_type ) == torch .ListType
441+ assert isinstance (val_type , torch .ListType )
443442 return self ._emit_list (
444443 typing .cast (List [_Argument ], val ),
445444 typing .cast (_SchemaType , val_type .getElementType ()),
0 commit comments