1818@dataclass
1919class TestResult :
2020 """Test result data structure"""
21+
2122 success : bool
2223 return_code : int # 0: success, -1: failure, -2: skipped, -3: partial
2324 torch_time : float = 0.0
@@ -57,26 +58,29 @@ def __init__(
5758 self .inputs = []
5859
5960 # Process inputs - support both single TensorSpecs and tuples of TensorSpecs
60- for inp in inputs :
61+ for i , inp in enumerate ( inputs ) :
6162 if isinstance (inp , (list , tuple )):
6263 # Handle tuple/list of multiple TensorSpecs (e.g., for torch.cat)
6364 processed_tuple = []
64- for item in inp :
65+ for j , item in enumerate ( inp ) :
6566 if isinstance (item , (list , tuple )):
6667 # Nested tuple - recursively process
6768 nested_processed = []
68- for nested_item in item :
69+ for k , nested_item in enumerate ( item ) :
6970 if isinstance (nested_item , TensorSpec ):
71+ nested_item .fill_name (f"in_{ i } _{ j } _{ k } " )
7072 nested_processed .append (nested_item )
7173 else :
7274 nested_processed .append (nested_item )
7375 processed_tuple .append (tuple (nested_processed ))
7476 elif isinstance (item , TensorSpec ):
77+ item .fill_name (f"in_{ i } _{ j } " )
7578 processed_tuple .append (item )
7679 else :
7780 processed_tuple .append (item )
7881 self .inputs .append (tuple (processed_tuple ))
7982 elif isinstance (inp , TensorSpec ):
83+ inp .fill_name (f"in_{ i } " )
8084 self .inputs .append (inp )
8185 else :
8286 self .inputs .append (inp )
@@ -89,6 +93,10 @@ def __init__(
8993 self .tolerance = tolerance or {"atol" : 1e-5 , "rtol" : 1e-3 }
9094 self .output_count = output_count
9195
96+ if self .output_count > 1 and self .output_specs is not None :
97+ for idx , spec in enumerate (self .output_specs ):
98+ spec .fill_name (f"out_{ idx } " )
99+
92100 # Validate output configuration
93101 if self .output_count == 1 :
94102 if self .output_specs is not None :
@@ -124,45 +132,15 @@ def __str__(self):
124132 # Handle tuple inputs (e.g., for torch.cat)
125133 tuple_strs = []
126134 for item in inp :
127- if hasattr (item , "is_scalar" ) and item .is_scalar :
128- dtype_str = f", dtype={ item .dtype } " if item .dtype else ""
129- tuple_strs .append (f"scalar({ item .value } { dtype_str } )" )
130- elif hasattr (item , "shape" ):
131- dtype_str = f", { item .dtype } " if item .dtype else ""
132- init_str = (
133- f", init={ item .init_mode } "
134- if item .init_mode != TensorInitializer .RANDOM
135- else ""
136- )
137- if hasattr (item , "strides" ) and item .strides :
138- strides_str = f", strides={ item .strides } "
139- tuple_strs .append (
140- f"tensor{ item .shape } { strides_str } { dtype_str } { init_str } "
141- )
142- else :
143- tuple_strs .append (
144- f"tensor{ item .shape } { dtype_str } { init_str } "
145- )
135+ if isinstance (item , (list , tuple )):
136+ # Handle nested tuples
137+ nested_strs = []
138+ for nested_item in item :
139+ nested_strs .append (str (nested_item ))
140+ tuple_strs .append (f"tuple({ ', ' .join (nested_strs )} )" )
146141 else :
147142 tuple_strs .append (str (item ))
148143 input_strs .append (f"tuple({ '; ' .join (tuple_strs )} )" )
149- elif hasattr (inp , "is_scalar" ) and inp .is_scalar :
150- dtype_str = f", dtype={ inp .dtype } " if inp .dtype else ""
151- input_strs .append (f"scalar({ inp .value } { dtype_str } )" )
152- elif hasattr (inp , "shape" ):
153- dtype_str = f", { inp .dtype } " if inp .dtype else ""
154- init_str = (
155- f", init={ inp .init_mode } "
156- if inp .init_mode != TensorInitializer .RANDOM
157- else ""
158- )
159- if hasattr (inp , "strides" ) and inp .strides :
160- strides_str = f", strides={ inp .strides } "
161- input_strs .append (
162- f"tensor{ inp .shape } { strides_str } { dtype_str } { init_str } "
163- )
164- else :
165- input_strs .append (f"tensor{ inp .shape } { dtype_str } { init_str } " )
166144 else :
167145 input_strs .append (str (inp ))
168146
@@ -175,48 +153,16 @@ def __str__(self):
175153 kwargs_strs = []
176154 for key , value in self .kwargs .items ():
177155 if key == "out" and isinstance (value , int ):
178- kwargs_strs .append (f"{ key } ={ value } " )
156+ kwargs_strs .append (f"{ key } ={ self . inputs [ value ]. name } " )
179157 else :
180158 kwargs_strs .append (f"{ key } ={ value } " )
181159
182- # Handle output specifications
160+ # Handle output specifications using TensorSpec's __str__
183161 if self .output_count == 1 and self .output_spec :
184- dtype_str = (
185- f", { self .output_spec .dtype } " if self .output_spec .dtype else ""
186- )
187- init_str = (
188- f", init={ self .output_spec .init_mode } "
189- if self .output_spec .init_mode != TensorInitializer .RANDOM
190- else ""
191- )
192- if hasattr (self .output_spec , "strides" ) and self .output_spec .strides :
193- strides_str = f", strides={ self .output_spec .strides } "
194- kwargs_strs .append (
195- f"out=tensor{ self .output_spec .shape } { strides_str } { dtype_str } { init_str } "
196- )
197- else :
198- kwargs_strs .append (
199- f"out=tensor{ self .output_spec .shape } { dtype_str } { init_str } "
200- )
162+ kwargs_strs .append (f"out={ self .output_spec } " )
201163 elif self .output_count > 1 and self .output_specs :
202- output_strs = []
203164 for i , spec in enumerate (self .output_specs ):
204- dtype_str = f", { spec .dtype } " if spec .dtype else ""
205- init_str = (
206- f", init={ spec .init_mode } "
207- if spec .init_mode != TensorInitializer .RANDOM
208- else ""
209- )
210- if hasattr (spec , "strides" ) and spec .strides :
211- strides_str = f", strides={ spec .strides } "
212- output_strs .append (
213- f"out_{ i } =tensor{ spec .shape } { strides_str } { dtype_str } { init_str } "
214- )
215- else :
216- output_strs .append (
217- f"out_{ i } =tensor{ spec .shape } { dtype_str } { init_str } "
218- )
219- kwargs_strs .extend (output_strs )
165+ kwargs_strs .append (f"out_{ i } ={ spec } " )
220166
221167 base_str += f", kwargs={{{ '; ' .join (kwargs_strs )} }}"
222168
@@ -300,11 +246,15 @@ def run_tests(self, devices, test_func, test_type="Test"):
300246 elif test_result .return_code == - 2 : # Skipped
301247 skip_msg = f"{ test_case } - { InfiniDeviceNames [device ]} - Both operators not implemented"
302248 self .skipped_tests .append (skip_msg )
303- print (f"\033 [93m⚠\033 [0m Both operators not implemented - test skipped" )
249+ print (
250+ f"\033 [93m⚠\033 [0m Both operators not implemented - test skipped"
251+ )
304252 elif test_result .return_code == - 3 : # Partial
305253 partial_msg = f"{ test_case } - { InfiniDeviceNames [device ]} - One operator not implemented"
306254 self .partial_tests .append (partial_msg )
307- print (f"\033 [93m⚠\033 [0m One operator not implemented - running single operator without comparison" )
255+ print (
256+ f"\033 [93m⚠\033 [0m One operator not implemented - running single operator without comparison"
257+ )
308258
309259 if self .config .verbose and test_result .return_code != 0 :
310260 return False
@@ -315,14 +265,14 @@ def run_tests(self, devices, test_func, test_type="Test"):
315265 )
316266 print (f"\033 [91m✗\033 [0m { error_msg } " )
317267 self .failed_tests .append (error_msg )
318-
268+
319269 # Create a failed TestResult
320270 failed_result = TestResult (
321271 success = False ,
322272 return_code = - 1 ,
323273 error_message = str (e ),
324274 test_case = test_case ,
325- device = device
275+ device = device ,
326276 )
327277 self .test_results .append (failed_result )
328278 # In verbose mode, print full traceback and stop execution
@@ -333,7 +283,11 @@ def run_tests(self, devices, test_func, test_type="Test"):
333283 if self .config .debug :
334284 raise
335285
336- return len (self .failed_tests ) == 0 and len (self .skipped_tests ) == 0 and len (self .partial_tests ) == 0
286+ return (
287+ len (self .failed_tests ) == 0
288+ and len (self .skipped_tests ) == 0
289+ and len (self .partial_tests ) == 0
290+ )
337291
338292 def print_summary (self ):
339293 """
@@ -514,13 +468,13 @@ def run_test(self, device, test_case, config):
514468 TestResult: Test result object containing status and timing information
515469 """
516470 device_str = torch_device_map [device ]
517-
471+
518472 # Initialize test result
519473 test_result = TestResult (
520474 success = False ,
521475 return_code = - 1 , # Default to failure
522476 test_case = test_case ,
523- device = device
477+ device = device ,
524478 )
525479
526480 # Prepare inputs and kwargs with actual tensors
0 commit comments