Skip to content

Commit 4653d80

Browse files
committed
issue/603 - reduced test tensor clones
1 parent bbda8d2 commit 4653d80

File tree

1 file changed

+73
-61
lines changed

1 file changed

+73
-61
lines changed

test/infinicore/framework/base.py

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,7 @@ def run_tests(self, devices, test_func, test_type="Test"):
262262
return False
263263

264264
except Exception as e:
265-
error_msg = (
266-
f"{test_case} - {InfiniDeviceNames[device]} - Error: {e}"
267-
)
265+
error_msg = f"Error: {e}"
268266
print(f"\033[91m✗\033[0m {error_msg}")
269267
self.failed_tests.append(error_msg)
270268

@@ -394,7 +392,7 @@ def _create_tensor_from_spec(self, spec, device):
394392
return spec.create_torch_tensor(device)
395393
return spec
396394

397-
def prepare_inputs_and_kwargs(self, test_case, device):
395+
def prepare_pytorch_inputs_and_kwargs(self, test_case, device):
398396
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
399397
Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs
400398
"""
@@ -457,6 +455,71 @@ def prepare_inputs_and_kwargs(self, test_case, device):
457455

458456
return inputs, kwargs
459457

458+
def prepare_infinicore_list(self, input_sequence, clone=False):
459+
cloned_tensors = []
460+
infini_list = []
461+
for item in input_sequence:
462+
if isinstance(item, torch.Tensor):
463+
if clone:
464+
cloned_item = item.clone().detach()
465+
infini_item = infinicore_tensor_from_torch(cloned_item)
466+
cloned_tensors.append(cloned_item)
467+
else:
468+
infini_item = infinicore_tensor_from_torch(item)
469+
else:
470+
infini_item = item
471+
infini_list.append(infini_item)
472+
return infini_list, cloned_tensors
473+
474+
def prepare_infinicore_inputs_and_kwargs(self, inputs, kwargs, comparison_target):
475+
cloned_tensors = []
476+
infini_inputs = []
477+
478+
# Prepare infinicore inputs - only clone if needed for comparison
479+
for i, inp in enumerate(inputs):
480+
if isinstance(inp, torch.Tensor):
481+
# Clone only if this input will be used for comparison
482+
if comparison_target == i:
483+
cloned_inp = inp.clone().detach()
484+
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
485+
cloned_tensors.append(cloned_inp)
486+
else:
487+
# For non-comparison inputs, we can use the original (but still need to convert)
488+
infini_tensor = infinicore_tensor_from_torch(inp)
489+
infini_inputs.append(infini_tensor)
490+
elif isinstance(inp, (tuple, list)):
491+
infini_list, cloned_list = self.prepare_infinicore_list(
492+
inp, comparison_target == i
493+
)
494+
infini_inputs.append(infini_list)
495+
cloned_tensors.append(cloned_list)
496+
else:
497+
infini_inputs.append(inp)
498+
499+
# Prepare infinicore kwargs
500+
infini_kwargs = {}
501+
for key, value in kwargs.items():
502+
if isinstance(value, torch.Tensor):
503+
# Check if this tensor is used for output comparison
504+
if key == "out" and comparison_target == "out":
505+
cloned_value = value.clone().detach()
506+
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
507+
cloned_tensors.append(cloned_value)
508+
elif key == "out" and isinstance(comparison_target, int):
509+
infini_kwargs[key] = infini_inputs[comparison_target]
510+
else:
511+
infini_kwargs[key] = infinicore_tensor_from_torch(value)
512+
elif isinstance(value, (tuple, list)):
513+
infini_list, cloned_list = self.prepare_infinicore_list(
514+
value, key == "out"
515+
)
516+
cloned_tensors.append(cloned_list)
517+
infini_kwargs[key] = infini_list
518+
else:
519+
infini_kwargs[key] = value
520+
521+
return infini_inputs, infini_kwargs, cloned_tensors
522+
460523
def run_test(self, device, test_case, config):
461524
"""
462525
Unified test execution flow
@@ -480,66 +543,15 @@ def run_test(self, device, test_case, config):
480543
)
481544

482545
# Prepare inputs and kwargs with actual tensors
483-
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
484-
485-
# For in-place operations on input tensors, we need to preserve the original state
486-
original_inputs = []
487-
if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
488-
# This is an in-place operation on an input tensor
489-
# Store original values for comparison
490-
for inp in inputs:
491-
if isinstance(inp, torch.Tensor):
492-
original_inputs.append(inp.clone().detach())
493-
else:
494-
original_inputs.append(inp)
495-
496-
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
497-
infini_inputs = []
498-
torch_input_clones = []
499-
500-
for inp in inputs:
501-
if isinstance(inp, torch.Tensor):
502-
cloned_inp = inp.clone().detach()
503-
torch_input_clones.append(cloned_inp)
504-
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
505-
infini_inputs.append(infini_tensor)
506-
else:
507-
infini_inputs.append(inp)
508-
509-
infini_kwargs = {}
510-
for key, value in kwargs.items():
511-
if isinstance(value, torch.Tensor):
512-
# Clone tensor and convert to infinicore
513-
cloned_value = value.clone().detach()
514-
torch_input_clones.append(cloned_value)
515-
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
516-
else:
517-
# Pass through non-tensor values (scalars, strings, etc.)
518-
infini_kwargs[key] = value
546+
inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)
519547

520548
# Determine comparison target
521549
comparison_target = test_case.comparison_target
522550

523-
# Handle infinicore output
524-
infini_kwargs = kwargs.copy()
525-
if "out" in infini_kwargs:
526-
out_value = infini_kwargs["out"]
527-
if isinstance(out_value, torch.Tensor):
528-
# Single tensor output
529-
if isinstance(comparison_target, int):
530-
infini_kwargs["out"] = infini_inputs[comparison_target]
531-
else:
532-
cloned_out = out_value.clone().detach()
533-
torch_input_clones.append(cloned_out)
534-
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
535-
elif isinstance(out_value, (tuple, list)):
536-
# Multiple tensor outputs
537-
infini_outputs = []
538-
for tensor in out_value:
539-
cloned_tensor = tensor.clone().detach()
540-
torch_input_clones.append(cloned_tensor)
541-
infini_outputs.append(infinicore_tensor_from_torch(cloned_tensor))
542-
infini_kwargs["out"] = tuple(infini_outputs)
551+
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
552+
infini_inputs, infini_kwargs, cloned_tensors = (
553+
self.prepare_infinicore_inputs_and_kwargs(inputs, kwargs, comparison_target)
554+
)
543555

544556
# Check operator implementations
545557
torch_implemented = True
@@ -700,7 +712,7 @@ def run_test(self, device, test_case, config):
700712

701713
is_valid = compare_fn(infini_comparison, torch_comparison)
702714
if not is_valid:
703-
raise AssertionError(f"Result comparison failed for {test_case}")
715+
raise AssertionError(f"Result comparison failed.")
704716

705717
# ==========================================================================
706718
# UNIFIED BENCHMARKING LOGIC

0 commit comments

Comments
 (0)