1818@dataclass
1919class TestResult :
2020 """Test result data structure"""
21+
2122 success : bool
2223 return_code : int # 0: success, -1: failure, -2: skipped, -3: partial
2324 torch_time : float = 0.0
@@ -57,26 +58,29 @@ def __init__(
5758 self .inputs = []
5859
5960 # Process inputs - support both single TensorSpecs and tuples of TensorSpecs
60- for inp in inputs :
61+ for i , inp in enumerate ( inputs ) :
6162 if isinstance (inp , (list , tuple )):
6263 # Handle tuple/list of multiple TensorSpecs (e.g., for torch.cat)
6364 processed_tuple = []
64- for item in inp :
65+ for j , item in enumerate ( inp ) :
6566 if isinstance (item , (list , tuple )):
6667 # Nested tuple - recursively process
6768 nested_processed = []
68- for nested_item in item :
69+ for k , nested_item in enumerate ( item ) :
6970 if isinstance (nested_item , TensorSpec ):
71+ nested_item .fill_name (f"in_{ i } _{ j } _{ k } " )
7072 nested_processed .append (nested_item )
7173 else :
7274 nested_processed .append (nested_item )
7375 processed_tuple .append (tuple (nested_processed ))
7476 elif isinstance (item , TensorSpec ):
77+ item .fill_name (f"in_{ i } _{ j } " )
7578 processed_tuple .append (item )
7679 else :
7780 processed_tuple .append (item )
7881 self .inputs .append (tuple (processed_tuple ))
7982 elif isinstance (inp , TensorSpec ):
83+ inp .fill_name (f"in_{ i } " )
8084 self .inputs .append (inp )
8185 else :
8286 self .inputs .append (inp )
@@ -89,6 +93,12 @@ def __init__(
8993 self .tolerance = tolerance or {"atol" : 1e-5 , "rtol" : 1e-3 }
9094 self .output_count = output_count
9195
96+ if self .output_count == 1 and self .output_spec is not None :
97+ self .output_spec .fill_name ("out" )
98+ elif self .output_count > 1 and self .output_specs is not None :
99+ for idx , spec in enumerate (self .output_specs ):
100+ spec .fill_name (f"output_{ idx } " )
101+
92102 # Validate output configuration
93103 if self .output_count == 1 :
94104 if self .output_specs is not None :
@@ -124,45 +134,15 @@ def __str__(self):
124134 # Handle tuple inputs (e.g., for torch.cat)
125135 tuple_strs = []
126136 for item in inp :
127- if hasattr (item , "is_scalar" ) and item .is_scalar :
128- dtype_str = f", dtype={ item .dtype } " if item .dtype else ""
129- tuple_strs .append (f"scalar({ item .value } { dtype_str } )" )
130- elif hasattr (item , "shape" ):
131- dtype_str = f", { item .dtype } " if item .dtype else ""
132- init_str = (
133- f", init={ item .init_mode } "
134- if item .init_mode != TensorInitializer .RANDOM
135- else ""
136- )
137- if hasattr (item , "strides" ) and item .strides :
138- strides_str = f", strides={ item .strides } "
139- tuple_strs .append (
140- f"tensor{ item .shape } { strides_str } { dtype_str } { init_str } "
141- )
142- else :
143- tuple_strs .append (
144- f"tensor{ item .shape } { dtype_str } { init_str } "
145- )
137+ if isinstance (item , (list , tuple )):
138+ # Handle nested tuples
139+ nested_strs = []
140+ for nested_item in item :
141+ nested_strs .append (str (nested_item ))
142+ tuple_strs .append (f"tuple({ ', ' .join (nested_strs )} )" )
146143 else :
147144 tuple_strs .append (str (item ))
148145 input_strs .append (f"tuple({ '; ' .join (tuple_strs )} )" )
149- elif hasattr (inp , "is_scalar" ) and inp .is_scalar :
150- dtype_str = f", dtype={ inp .dtype } " if inp .dtype else ""
151- input_strs .append (f"scalar({ inp .value } { dtype_str } )" )
152- elif hasattr (inp , "shape" ):
153- dtype_str = f", { inp .dtype } " if inp .dtype else ""
154- init_str = (
155- f", init={ inp .init_mode } "
156- if inp .init_mode != TensorInitializer .RANDOM
157- else ""
158- )
159- if hasattr (inp , "strides" ) and inp .strides :
160- strides_str = f", strides={ inp .strides } "
161- input_strs .append (
162- f"tensor{ inp .shape } { strides_str } { dtype_str } { init_str } "
163- )
164- else :
165- input_strs .append (f"tensor{ inp .shape } { dtype_str } { init_str } " )
166146 else :
167147 input_strs .append (str (inp ))
168148
@@ -175,48 +155,16 @@ def __str__(self):
175155 kwargs_strs = []
176156 for key , value in self .kwargs .items ():
177157 if key == "out" and isinstance (value , int ):
178- kwargs_strs .append (f"{ key } ={ value } " )
158+ kwargs_strs .append (f"{ key } ={ self . inputs [ value ]. name } " )
179159 else :
180160 kwargs_strs .append (f"{ key } ={ value } " )
181161
182- # Handle output specifications
162+ # Handle output specifications using TensorSpec's __str__
183163 if self .output_count == 1 and self .output_spec :
184- dtype_str = (
185- f", { self .output_spec .dtype } " if self .output_spec .dtype else ""
186- )
187- init_str = (
188- f", init={ self .output_spec .init_mode } "
189- if self .output_spec .init_mode != TensorInitializer .RANDOM
190- else ""
191- )
192- if hasattr (self .output_spec , "strides" ) and self .output_spec .strides :
193- strides_str = f", strides={ self .output_spec .strides } "
194- kwargs_strs .append (
195- f"out=tensor{ self .output_spec .shape } { strides_str } { dtype_str } { init_str } "
196- )
197- else :
198- kwargs_strs .append (
199- f"out=tensor{ self .output_spec .shape } { dtype_str } { init_str } "
200- )
164+ kwargs_strs .append (f"out={ self .output_spec } " )
201165 elif self .output_count > 1 and self .output_specs :
202- output_strs = []
203166 for i , spec in enumerate (self .output_specs ):
204- dtype_str = f", { spec .dtype } " if spec .dtype else ""
205- init_str = (
206- f", init={ spec .init_mode } "
207- if spec .init_mode != TensorInitializer .RANDOM
208- else ""
209- )
210- if hasattr (spec , "strides" ) and spec .strides :
211- strides_str = f", strides={ spec .strides } "
212- output_strs .append (
213- f"out_{ i } =tensor{ spec .shape } { strides_str } { dtype_str } { init_str } "
214- )
215- else :
216- output_strs .append (
217- f"out_{ i } =tensor{ spec .shape } { dtype_str } { init_str } "
218- )
219- kwargs_strs .extend (output_strs )
167+ kwargs_strs .append (f"out_{ i } ={ spec } " )
220168
221169 base_str += f", kwargs={{{ '; ' .join (kwargs_strs )} }}"
222170
@@ -300,11 +248,15 @@ def run_tests(self, devices, test_func, test_type="Test"):
300248 elif test_result .return_code == - 2 : # Skipped
301249 skip_msg = f"{ test_case } - { InfiniDeviceNames [device ]} - Both operators not implemented"
302250 self .skipped_tests .append (skip_msg )
303- print (f"\033 [93m⚠\033 [0m Both operators not implemented - test skipped" )
251+ print (
252+ f"\033 [93m⚠\033 [0m Both operators not implemented - test skipped"
253+ )
304254 elif test_result .return_code == - 3 : # Partial
305255 partial_msg = f"{ test_case } - { InfiniDeviceNames [device ]} - One operator not implemented"
306256 self .partial_tests .append (partial_msg )
307- print (f"\033 [93m⚠\033 [0m One operator not implemented - running single operator without comparison" )
257+ print (
258+ f"\033 [93m⚠\033 [0m One operator not implemented - running single operator without comparison"
259+ )
308260
309261 if self .config .verbose and test_result .return_code != 0 :
310262 return False
@@ -315,14 +267,14 @@ def run_tests(self, devices, test_func, test_type="Test"):
315267 )
316268 print (f"\033 [91m✗\033 [0m { error_msg } " )
317269 self .failed_tests .append (error_msg )
318-
270+
319271 # Create a failed TestResult
320272 failed_result = TestResult (
321273 success = False ,
322274 return_code = - 1 ,
323275 error_message = str (e ),
324276 test_case = test_case ,
325- device = device
277+ device = device ,
326278 )
327279 self .test_results .append (failed_result )
328280 # In verbose mode, print full traceback and stop execution
@@ -333,7 +285,11 @@ def run_tests(self, devices, test_func, test_type="Test"):
333285 if self .config .debug :
334286 raise
335287
336- return len (self .failed_tests ) == 0 and len (self .skipped_tests ) == 0 and len (self .partial_tests ) == 0
288+ return (
289+ len (self .failed_tests ) == 0
290+ and len (self .skipped_tests ) == 0
291+ and len (self .partial_tests ) == 0
292+ )
337293
338294 def print_summary (self ):
339295 """
@@ -514,13 +470,13 @@ def run_test(self, device, test_case, config):
514470 TestResult: Test result object containing status and timing information
515471 """
516472 device_str = torch_device_map [device ]
517-
473+
518474 # Initialize test result
519475 test_result = TestResult (
520476 success = False ,
521477 return_code = - 1 , # Default to failure
522478 test_case = test_case ,
523- device = device
479+ device = device ,
524480 )
525481
526482 # Prepare inputs and kwargs with actual tensors
0 commit comments