@@ -79,22 +79,16 @@ def extract_tensor_meta(meta):
7979# Class to capture arguments and turn into tensor references for TOSA OPs
8080class TosaArg :
8181 def __process_node (self , argument : torch .fx .Node ):
82- self .name = argument .name
82+ self .name : str = argument .name
8383 self .dtype , self .shape , self .dim_order = extract_tensor_meta (argument .meta )
8484
8585 def __process_list (self , argument ):
86- self .special = list (argument )
86+ self .special : list = list (argument )
8787
8888 def __process_number (self , argument : float | int ):
89- self .number = argument
89+ self .number : float | int = argument
9090
9191 def __init__ (self , argument : Any ) -> None :
92- self .name = None # type: ignore[assignment]
93- self .dtype = None
94- self .shape = None
95- self .dim_order = None
96- self .special = None
97-
9892 if argument is None :
9993 return
10094
@@ -114,3 +108,20 @@ def __init__(self, argument: Any) -> None:
114108 raise RuntimeError (
115109 f"Unhandled node input argument: { argument } , of type { type (argument )} "
116110 )
111+
112+ def __repr__ (self ):
113+ attrs = []
114+ if hasattr (self , "name" ):
115+ if self .name is not None :
116+ attrs .append (f"name={ self .name !r} " )
117+ if self .dtype is not None :
118+ attrs .append (f"dtype={ ts .DTypeNames [self .dtype ]} " )
119+ if self .shape is not None :
120+ attrs .append (f"shape={ self .shape !r} " )
121+ if self .dim_order is not None :
122+ attrs .append (f"dim_order={ self .dim_order !r} " )
123+ if hasattr (self , "special" ) and self .special is not None :
124+ attrs .append (f"special={ self .special !r} " )
125+ if hasattr (self , "number" ) and self .number is not None :
126+ attrs .append (f"number={ self .number !r} " )
127+ return f"{ self .__class__ .__name__ } ({ ', ' .join (attrs )} )"
0 commit comments