Skip to content

Commit 17f6513

Browse files
Merge pull request #604 from InfiniTensor/issue/602
Issue/602 - 给变量起个名字
2 parents b7d9252 + c8df7bd commit 17f6513

File tree

3 files changed

+50
-88
lines changed

3 files changed

+50
-88
lines changed

test/infinicore/framework/base.py

Lines changed: 36 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
@dataclass
1919
class TestResult:
2020
"""Test result data structure"""
21+
2122
success: bool
2223
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
2324
torch_time: float = 0.0
@@ -57,26 +58,29 @@ def __init__(
5758
self.inputs = []
5859

5960
# Process inputs - support both single TensorSpecs and tuples of TensorSpecs
60-
for inp in inputs:
61+
for i, inp in enumerate(inputs):
6162
if isinstance(inp, (list, tuple)):
6263
# Handle tuple/list of multiple TensorSpecs (e.g., for torch.cat)
6364
processed_tuple = []
64-
for item in inp:
65+
for j, item in enumerate(inp):
6566
if isinstance(item, (list, tuple)):
6667
# Nested tuple - recursively process
6768
nested_processed = []
68-
for nested_item in item:
69+
for k, nested_item in enumerate(item):
6970
if isinstance(nested_item, TensorSpec):
71+
nested_item.fill_name(f"in_{i}_{j}_{k}")
7072
nested_processed.append(nested_item)
7173
else:
7274
nested_processed.append(nested_item)
7375
processed_tuple.append(tuple(nested_processed))
7476
elif isinstance(item, TensorSpec):
77+
item.fill_name(f"in_{i}_{j}")
7578
processed_tuple.append(item)
7679
else:
7780
processed_tuple.append(item)
7881
self.inputs.append(tuple(processed_tuple))
7982
elif isinstance(inp, TensorSpec):
83+
inp.fill_name(f"in_{i}")
8084
self.inputs.append(inp)
8185
else:
8286
self.inputs.append(inp)
@@ -89,6 +93,10 @@ def __init__(
8993
self.tolerance = tolerance or {"atol": 1e-5, "rtol": 1e-3}
9094
self.output_count = output_count
9195

96+
if self.output_count > 1 and self.output_specs is not None:
97+
for idx, spec in enumerate(self.output_specs):
98+
spec.fill_name(f"out_{idx}")
99+
92100
# Validate output configuration
93101
if self.output_count == 1:
94102
if self.output_specs is not None:
@@ -124,45 +132,15 @@ def __str__(self):
124132
# Handle tuple inputs (e.g., for torch.cat)
125133
tuple_strs = []
126134
for item in inp:
127-
if hasattr(item, "is_scalar") and item.is_scalar:
128-
dtype_str = f", dtype={item.dtype}" if item.dtype else ""
129-
tuple_strs.append(f"scalar({item.value}{dtype_str})")
130-
elif hasattr(item, "shape"):
131-
dtype_str = f", {item.dtype}" if item.dtype else ""
132-
init_str = (
133-
f", init={item.init_mode}"
134-
if item.init_mode != TensorInitializer.RANDOM
135-
else ""
136-
)
137-
if hasattr(item, "strides") and item.strides:
138-
strides_str = f", strides={item.strides}"
139-
tuple_strs.append(
140-
f"tensor{item.shape}{strides_str}{dtype_str}{init_str}"
141-
)
142-
else:
143-
tuple_strs.append(
144-
f"tensor{item.shape}{dtype_str}{init_str}"
145-
)
135+
if isinstance(item, (list, tuple)):
136+
# Handle nested tuples
137+
nested_strs = []
138+
for nested_item in item:
139+
nested_strs.append(str(nested_item))
140+
tuple_strs.append(f"tuple({', '.join(nested_strs)})")
146141
else:
147142
tuple_strs.append(str(item))
148143
input_strs.append(f"tuple({'; '.join(tuple_strs)})")
149-
elif hasattr(inp, "is_scalar") and inp.is_scalar:
150-
dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
151-
input_strs.append(f"scalar({inp.value}{dtype_str})")
152-
elif hasattr(inp, "shape"):
153-
dtype_str = f", {inp.dtype}" if inp.dtype else ""
154-
init_str = (
155-
f", init={inp.init_mode}"
156-
if inp.init_mode != TensorInitializer.RANDOM
157-
else ""
158-
)
159-
if hasattr(inp, "strides") and inp.strides:
160-
strides_str = f", strides={inp.strides}"
161-
input_strs.append(
162-
f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}"
163-
)
164-
else:
165-
input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}")
166144
else:
167145
input_strs.append(str(inp))
168146

@@ -175,48 +153,16 @@ def __str__(self):
175153
kwargs_strs = []
176154
for key, value in self.kwargs.items():
177155
if key == "out" and isinstance(value, int):
178-
kwargs_strs.append(f"{key}={value}")
156+
kwargs_strs.append(f"{key}={self.inputs[value].name}")
179157
else:
180158
kwargs_strs.append(f"{key}={value}")
181159

182-
# Handle output specifications
160+
# Handle output specifications using TensorSpec's __str__
183161
if self.output_count == 1 and self.output_spec:
184-
dtype_str = (
185-
f", {self.output_spec.dtype}" if self.output_spec.dtype else ""
186-
)
187-
init_str = (
188-
f", init={self.output_spec.init_mode}"
189-
if self.output_spec.init_mode != TensorInitializer.RANDOM
190-
else ""
191-
)
192-
if hasattr(self.output_spec, "strides") and self.output_spec.strides:
193-
strides_str = f", strides={self.output_spec.strides}"
194-
kwargs_strs.append(
195-
f"out=tensor{self.output_spec.shape}{strides_str}{dtype_str}{init_str}"
196-
)
197-
else:
198-
kwargs_strs.append(
199-
f"out=tensor{self.output_spec.shape}{dtype_str}{init_str}"
200-
)
162+
kwargs_strs.append(f"out={self.output_spec}")
201163
elif self.output_count > 1 and self.output_specs:
202-
output_strs = []
203164
for i, spec in enumerate(self.output_specs):
204-
dtype_str = f", {spec.dtype}" if spec.dtype else ""
205-
init_str = (
206-
f", init={spec.init_mode}"
207-
if spec.init_mode != TensorInitializer.RANDOM
208-
else ""
209-
)
210-
if hasattr(spec, "strides") and spec.strides:
211-
strides_str = f", strides={spec.strides}"
212-
output_strs.append(
213-
f"out_{i}=tensor{spec.shape}{strides_str}{dtype_str}{init_str}"
214-
)
215-
else:
216-
output_strs.append(
217-
f"out_{i}=tensor{spec.shape}{dtype_str}{init_str}"
218-
)
219-
kwargs_strs.extend(output_strs)
165+
kwargs_strs.append(f"out_{i}={spec}")
220166

221167
base_str += f", kwargs={{{'; '.join(kwargs_strs)}}}"
222168

@@ -300,11 +246,15 @@ def run_tests(self, devices, test_func, test_type="Test"):
300246
elif test_result.return_code == -2: # Skipped
301247
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
302248
self.skipped_tests.append(skip_msg)
303-
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
249+
print(
250+
f"\033[93m⚠\033[0m Both operators not implemented - test skipped"
251+
)
304252
elif test_result.return_code == -3: # Partial
305253
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
306254
self.partial_tests.append(partial_msg)
307-
print(f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison")
255+
print(
256+
f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison"
257+
)
308258

309259
if self.config.verbose and test_result.return_code != 0:
310260
return False
@@ -315,14 +265,14 @@ def run_tests(self, devices, test_func, test_type="Test"):
315265
)
316266
print(f"\033[91m✗\033[0m {error_msg}")
317267
self.failed_tests.append(error_msg)
318-
268+
319269
# Create a failed TestResult
320270
failed_result = TestResult(
321271
success=False,
322272
return_code=-1,
323273
error_message=str(e),
324274
test_case=test_case,
325-
device=device
275+
device=device,
326276
)
327277
self.test_results.append(failed_result)
328278
# In verbose mode, print full traceback and stop execution
@@ -333,7 +283,11 @@ def run_tests(self, devices, test_func, test_type="Test"):
333283
if self.config.debug:
334284
raise
335285

336-
return len(self.failed_tests) == 0 and len(self.skipped_tests) == 0 and len(self.partial_tests) == 0
286+
return (
287+
len(self.failed_tests) == 0
288+
and len(self.skipped_tests) == 0
289+
and len(self.partial_tests) == 0
290+
)
337291

338292
def print_summary(self):
339293
"""
@@ -514,13 +468,13 @@ def run_test(self, device, test_case, config):
514468
TestResult: Test result object containing status and timing information
515469
"""
516470
device_str = torch_device_map[device]
517-
471+
518472
# Initialize test result
519473
test_result = TestResult(
520474
success=False,
521475
return_code=-1, # Default to failure
522476
test_case=test_case,
523-
device=device
477+
device=device,
524478
)
525479

526480
# Prepare inputs and kwargs with actual tensors

test/infinicore/framework/tensor.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def __init__(
284284
self.is_scalar = is_scalar
285285
self.init_mode = init_mode
286286
self.kwargs = kwargs
287+
self.name = kwargs.get("name") if kwargs.get("name") else None
287288

288289
@classmethod
289290
def from_tensor(
@@ -339,10 +340,17 @@ def is_tensor_input(self):
339340
"""Check if this spec represents a tensor input (not scalar)"""
340341
return not self.is_scalar
341342

343+
def fill_name(self, name):
344+
if self.name is None:
345+
self.name = name
346+
342347
def __str__(self):
348+
name_str = f"{self.name}: " if self.name else ""
343349
if self.is_scalar:
344-
return f"scalar({self.value})"
350+
return f"{name_str}scalar({self.value})"
345351
else:
346352
strides_str = f", strides={self.strides}" if self.strides else ""
347-
dtype_str = f", dtype={self.dtype}" if self.dtype else ""
348-
return f"tensor{self.shape}{strides_str}{dtype_str}"
353+
dtype_str = (
354+
f", {str(self.dtype).replace("infinicore.", "")}" if self.dtype else ""
355+
)
356+
return f"{name_str}tensor{self.shape}{strides_str}{dtype_str}"

test/infinicore/ops/add.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def parse_test_cases():
6565
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
6666

6767
# Create typed tensor specs
68-
a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
69-
b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
70-
c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)
68+
a_spec = TensorSpec.from_tensor(shape, a_strides, dtype, name="a")
69+
b_spec = TensorSpec.from_tensor(shape, b_strides, dtype, name="b")
70+
c_spec = TensorSpec.from_tensor(shape, c_strides, dtype, name="c")
7171

7272
# Test Case 1: Out-of-place (return value)
7373
test_cases.append(

0 commit comments

Comments
 (0)