Skip to content

Commit 450bb54

Browse files
committed
issue/602 - added an arg naming option and its auto-fill
1 parent b7d9252 commit 450bb54

File tree

3 files changed

+52
-88
lines changed

3 files changed

+52
-88
lines changed

test/infinicore/framework/base.py

Lines changed: 38 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
@dataclass
1919
class TestResult:
2020
"""Test result data structure"""
21+
2122
success: bool
2223
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
2324
torch_time: float = 0.0
@@ -57,26 +58,29 @@ def __init__(
5758
self.inputs = []
5859

5960
# Process inputs - support both single TensorSpecs and tuples of TensorSpecs
60-
for inp in inputs:
61+
for i, inp in enumerate(inputs):
6162
if isinstance(inp, (list, tuple)):
6263
# Handle tuple/list of multiple TensorSpecs (e.g., for torch.cat)
6364
processed_tuple = []
64-
for item in inp:
65+
for j, item in enumerate(inp):
6566
if isinstance(item, (list, tuple)):
6667
# Nested tuple - recursively process
6768
nested_processed = []
68-
for nested_item in item:
69+
for k, nested_item in enumerate(item):
6970
if isinstance(nested_item, TensorSpec):
71+
nested_item.fill_name(f"in_{i}_{j}_{k}")
7072
nested_processed.append(nested_item)
7173
else:
7274
nested_processed.append(nested_item)
7375
processed_tuple.append(tuple(nested_processed))
7476
elif isinstance(item, TensorSpec):
77+
item.fill_name(f"in_{i}_{j}")
7578
processed_tuple.append(item)
7679
else:
7780
processed_tuple.append(item)
7881
self.inputs.append(tuple(processed_tuple))
7982
elif isinstance(inp, TensorSpec):
83+
inp.fill_name(f"in_{i}")
8084
self.inputs.append(inp)
8185
else:
8286
self.inputs.append(inp)
@@ -89,6 +93,12 @@ def __init__(
8993
self.tolerance = tolerance or {"atol": 1e-5, "rtol": 1e-3}
9094
self.output_count = output_count
9195

96+
if self.output_count == 1 and self.output_spec is not None:
97+
self.output_spec.fill_name("out")
98+
elif self.output_count > 1 and self.output_specs is not None:
99+
for idx, spec in enumerate(self.output_specs):
100+
spec.fill_name(f"output_{idx}")
101+
92102
# Validate output configuration
93103
if self.output_count == 1:
94104
if self.output_specs is not None:
@@ -124,45 +134,15 @@ def __str__(self):
124134
# Handle tuple inputs (e.g., for torch.cat)
125135
tuple_strs = []
126136
for item in inp:
127-
if hasattr(item, "is_scalar") and item.is_scalar:
128-
dtype_str = f", dtype={item.dtype}" if item.dtype else ""
129-
tuple_strs.append(f"scalar({item.value}{dtype_str})")
130-
elif hasattr(item, "shape"):
131-
dtype_str = f", {item.dtype}" if item.dtype else ""
132-
init_str = (
133-
f", init={item.init_mode}"
134-
if item.init_mode != TensorInitializer.RANDOM
135-
else ""
136-
)
137-
if hasattr(item, "strides") and item.strides:
138-
strides_str = f", strides={item.strides}"
139-
tuple_strs.append(
140-
f"tensor{item.shape}{strides_str}{dtype_str}{init_str}"
141-
)
142-
else:
143-
tuple_strs.append(
144-
f"tensor{item.shape}{dtype_str}{init_str}"
145-
)
137+
if isinstance(item, (list, tuple)):
138+
# Handle nested tuples
139+
nested_strs = []
140+
for nested_item in item:
141+
nested_strs.append(str(nested_item))
142+
tuple_strs.append(f"tuple({', '.join(nested_strs)})")
146143
else:
147144
tuple_strs.append(str(item))
148145
input_strs.append(f"tuple({'; '.join(tuple_strs)})")
149-
elif hasattr(inp, "is_scalar") and inp.is_scalar:
150-
dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
151-
input_strs.append(f"scalar({inp.value}{dtype_str})")
152-
elif hasattr(inp, "shape"):
153-
dtype_str = f", {inp.dtype}" if inp.dtype else ""
154-
init_str = (
155-
f", init={inp.init_mode}"
156-
if inp.init_mode != TensorInitializer.RANDOM
157-
else ""
158-
)
159-
if hasattr(inp, "strides") and inp.strides:
160-
strides_str = f", strides={inp.strides}"
161-
input_strs.append(
162-
f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}"
163-
)
164-
else:
165-
input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}")
166146
else:
167147
input_strs.append(str(inp))
168148

@@ -175,48 +155,16 @@ def __str__(self):
175155
kwargs_strs = []
176156
for key, value in self.kwargs.items():
177157
if key == "out" and isinstance(value, int):
178-
kwargs_strs.append(f"{key}={value}")
158+
kwargs_strs.append(f"{key}={self.inputs[value].name}")
179159
else:
180160
kwargs_strs.append(f"{key}={value}")
181161

182-
# Handle output specifications
162+
# Handle output specifications using TensorSpec's __str__
183163
if self.output_count == 1 and self.output_spec:
184-
dtype_str = (
185-
f", {self.output_spec.dtype}" if self.output_spec.dtype else ""
186-
)
187-
init_str = (
188-
f", init={self.output_spec.init_mode}"
189-
if self.output_spec.init_mode != TensorInitializer.RANDOM
190-
else ""
191-
)
192-
if hasattr(self.output_spec, "strides") and self.output_spec.strides:
193-
strides_str = f", strides={self.output_spec.strides}"
194-
kwargs_strs.append(
195-
f"out=tensor{self.output_spec.shape}{strides_str}{dtype_str}{init_str}"
196-
)
197-
else:
198-
kwargs_strs.append(
199-
f"out=tensor{self.output_spec.shape}{dtype_str}{init_str}"
200-
)
164+
kwargs_strs.append(f"out={self.output_spec}")
201165
elif self.output_count > 1 and self.output_specs:
202-
output_strs = []
203166
for i, spec in enumerate(self.output_specs):
204-
dtype_str = f", {spec.dtype}" if spec.dtype else ""
205-
init_str = (
206-
f", init={spec.init_mode}"
207-
if spec.init_mode != TensorInitializer.RANDOM
208-
else ""
209-
)
210-
if hasattr(spec, "strides") and spec.strides:
211-
strides_str = f", strides={spec.strides}"
212-
output_strs.append(
213-
f"out_{i}=tensor{spec.shape}{strides_str}{dtype_str}{init_str}"
214-
)
215-
else:
216-
output_strs.append(
217-
f"out_{i}=tensor{spec.shape}{dtype_str}{init_str}"
218-
)
219-
kwargs_strs.extend(output_strs)
167+
kwargs_strs.append(f"out_{i}={spec}")
220168

221169
base_str += f", kwargs={{{'; '.join(kwargs_strs)}}}"
222170

@@ -300,11 +248,15 @@ def run_tests(self, devices, test_func, test_type="Test"):
300248
elif test_result.return_code == -2: # Skipped
301249
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
302250
self.skipped_tests.append(skip_msg)
303-
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
251+
print(
252+
f"\033[93m⚠\033[0m Both operators not implemented - test skipped"
253+
)
304254
elif test_result.return_code == -3: # Partial
305255
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
306256
self.partial_tests.append(partial_msg)
307-
print(f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison")
257+
print(
258+
f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison"
259+
)
308260

309261
if self.config.verbose and test_result.return_code != 0:
310262
return False
@@ -315,14 +267,14 @@ def run_tests(self, devices, test_func, test_type="Test"):
315267
)
316268
print(f"\033[91m✗\033[0m {error_msg}")
317269
self.failed_tests.append(error_msg)
318-
270+
319271
# Create a failed TestResult
320272
failed_result = TestResult(
321273
success=False,
322274
return_code=-1,
323275
error_message=str(e),
324276
test_case=test_case,
325-
device=device
277+
device=device,
326278
)
327279
self.test_results.append(failed_result)
328280
# In verbose mode, print full traceback and stop execution
@@ -333,7 +285,11 @@ def run_tests(self, devices, test_func, test_type="Test"):
333285
if self.config.debug:
334286
raise
335287

336-
return len(self.failed_tests) == 0 and len(self.skipped_tests) == 0 and len(self.partial_tests) == 0
288+
return (
289+
len(self.failed_tests) == 0
290+
and len(self.skipped_tests) == 0
291+
and len(self.partial_tests) == 0
292+
)
337293

338294
def print_summary(self):
339295
"""
@@ -514,13 +470,13 @@ def run_test(self, device, test_case, config):
514470
TestResult: Test result object containing status and timing information
515471
"""
516472
device_str = torch_device_map[device]
517-
473+
518474
# Initialize test result
519475
test_result = TestResult(
520476
success=False,
521477
return_code=-1, # Default to failure
522478
test_case=test_case,
523-
device=device
479+
device=device,
524480
)
525481

526482
# Prepare inputs and kwargs with actual tensors

test/infinicore/framework/tensor.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def __init__(
284284
self.is_scalar = is_scalar
285285
self.init_mode = init_mode
286286
self.kwargs = kwargs
287+
self.name = kwargs.get("name") if kwargs.get("name") else None
287288

288289
@classmethod
289290
def from_tensor(
@@ -339,10 +340,17 @@ def is_tensor_input(self):
339340
"""Check if this spec represents a tensor input (not scalar)"""
340341
return not self.is_scalar
341342

343+
def fill_name(self, name):
344+
if self.name is None:
345+
self.name = name
346+
342347
def __str__(self):
348+
name_str = f"{self.name}: " if self.name else ""
343349
if self.is_scalar:
344-
return f"scalar({self.value})"
350+
return f"{name_str}scalar({self.value})"
345351
else:
346352
strides_str = f", strides={self.strides}" if self.strides else ""
347-
dtype_str = f", dtype={self.dtype}" if self.dtype else ""
348-
return f"tensor{self.shape}{strides_str}{dtype_str}"
353+
dtype_str = (
354+
f", {str(self.dtype).replace("infinicore.", "")}" if self.dtype else ""
355+
)
356+
return f"{name_str}tensor{self.shape}{strides_str}{dtype_str}"

test/infinicore/ops/add.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def parse_test_cases():
6565
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
6666

6767
# Create typed tensor specs
68-
a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
69-
b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
70-
c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)
68+
a_spec = TensorSpec.from_tensor(shape, a_strides, dtype, name="a")
69+
b_spec = TensorSpec.from_tensor(shape, b_strides, dtype, name="b")
70+
c_spec = TensorSpec.from_tensor(shape, c_strides, dtype, name="c")
7171

7272
# Test Case 1: Out-of-place (return value)
7373
test_cases.append(

0 commit comments

Comments
 (0)