Skip to content

Commit e114dc1

Browse files
committed
issue/603 - reduced test tensor clones
1 parent c8df7bd commit e114dc1

File tree

2 files changed

+74
-62
lines changed

2 files changed

+74
-62
lines changed

src/infinicore/ops/random_sample/random_sample_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ static void calculate(
3535

3636
if (!desc_opt) {
3737
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
38-
context::getInfiniopHandle(), &desc,
38+
context::getInfiniopHandle(indices->device()), &desc,
3939
indices->desc(), logits->desc()));
4040
cache.put(seed, desc);
4141
} else {

test/infinicore/framework/base.py

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,7 @@ def run_tests(self, devices, test_func, test_type="Test"):
260260
return False
261261

262262
except Exception as e:
263-
error_msg = (
264-
f"{test_case} - {InfiniDeviceNames[device]} - Error: {e}"
265-
)
263+
error_msg = f"Error: {e}"
266264
print(f"\033[91m✗\033[0m {error_msg}")
267265
self.failed_tests.append(error_msg)
268266

@@ -392,7 +390,7 @@ def _create_tensor_from_spec(self, spec, device):
392390
return spec.create_torch_tensor(device)
393391
return spec
394392

395-
def prepare_inputs_and_kwargs(self, test_case, device):
393+
def prepare_pytorch_inputs_and_kwargs(self, test_case, device):
396394
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
397395
Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs
398396
"""
@@ -455,6 +453,71 @@ def prepare_inputs_and_kwargs(self, test_case, device):
455453

456454
return inputs, kwargs
457455

456+
def prepare_infinicore_list(self, input_sequence, clone=False):
    """Convert the torch tensors of a sequence into infinicore tensors.

    Args:
        input_sequence: Iterable of operator arguments. ``torch.Tensor``
            items are converted with ``infinicore_tensor_from_torch``;
            every other item is passed through unchanged.
        clone: When true, each tensor is copied first and the infinicore
            tensor is built from the copy instead of the caller's tensor.

    Returns:
        tuple: ``(infini_list, cloned_tensors)``. ``infini_list`` has one
        entry per input item, in the original order. ``cloned_tensors``
        holds the torch copies made when ``clone`` is true and is empty
        otherwise.
    """
    cloned_tensors = []
    infini_list = []
    for item in input_sequence:
        # Non-tensor arguments (scalars, strings, ...) need no conversion.
        if not isinstance(item, torch.Tensor):
            infini_list.append(item)
            continue
        if clone:
            # detach() before clone() so the copy is never recorded by
            # autograd; the resulting tensor is the same either way.
            item = item.detach().clone()
            cloned_tensors.append(item)
        infini_list.append(infinicore_tensor_from_torch(item))
    return infini_list, cloned_tensors
471+
472+
def prepare_infinicore_inputs_and_kwargs(self, inputs, kwargs, comparison_target):
    """Build the infinicore positional and keyword arguments for a test.

    Torch tensors are converted with ``infinicore_tensor_from_torch``. A
    tensor is copied before conversion only when it is the comparison
    target, so the remaining arguments are converted without a clone.

    Args:
        inputs: Positional arguments: torch tensors, tuples/lists that
            may contain tensors, or arbitrary pass-through values.
        kwargs: Keyword arguments holding the same kinds of values.
        comparison_target: Index of the positional input that is used
            for comparison, the string ``"out"`` when ``kwargs["out"]``
            is used for comparison, or any other value (e.g. ``None``)
            when neither applies.

    Returns:
        tuple: ``(infini_inputs, infini_kwargs, cloned_tensors)``.
        ``cloned_tensors`` is a flat list of every torch copy made here,
        including copies made inside tuple/list arguments.
    """
    cloned_tensors = []
    infini_inputs = []

    # Positional arguments - only the comparison target is cloned.
    for i, inp in enumerate(inputs):
        is_target = comparison_target == i
        if isinstance(inp, torch.Tensor):
            if is_target:
                # detach() before clone() keeps the copy out of autograd.
                inp = inp.detach().clone()
                cloned_tensors.append(inp)
            infini_inputs.append(infinicore_tensor_from_torch(inp))
        elif isinstance(inp, (tuple, list)):
            infini_list, cloned_list = self.prepare_infinicore_list(
                inp, is_target
            )
            infini_inputs.append(infini_list)
            # extend (not append) so cloned_tensors stays a flat list of
            # tensors instead of mixing tensors with nested lists.
            cloned_tensors.extend(cloned_list)
        else:
            infini_inputs.append(inp)

    # Keyword arguments.
    infini_kwargs = {}
    for key, value in kwargs.items():
        if isinstance(value, torch.Tensor):
            if key == "out" and comparison_target == "out":
                # The output buffer itself is compared: work on a copy.
                cloned_value = value.detach().clone()
                infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
                cloned_tensors.append(cloned_value)
            elif key == "out" and isinstance(comparison_target, int):
                # Reuse the infinicore tensor already built for that
                # positional input (presumably an in-place style call).
                infini_kwargs[key] = infini_inputs[comparison_target]
            else:
                infini_kwargs[key] = infinicore_tensor_from_torch(value)
        elif isinstance(value, (tuple, list)):
            # Tensors inside a tuple/list "out" argument are always cloned.
            infini_list, cloned_list = self.prepare_infinicore_list(
                value, key == "out"
            )
            cloned_tensors.extend(cloned_list)
            infini_kwargs[key] = infini_list
        else:
            infini_kwargs[key] = value

    return infini_inputs, infini_kwargs, cloned_tensors
520+
458521
def run_test(self, device, test_case, config):
459522
"""
460523
Unified test execution flow
@@ -478,66 +541,15 @@ def run_test(self, device, test_case, config):
478541
)
479542

480543
# Prepare inputs and kwargs with actual tensors
481-
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
482-
483-
# For in-place operations on input tensors, we need to preserve the original state
484-
original_inputs = []
485-
if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
486-
# This is an in-place operation on an input tensor
487-
# Store original values for comparison
488-
for inp in inputs:
489-
if isinstance(inp, torch.Tensor):
490-
original_inputs.append(inp.clone().detach())
491-
else:
492-
original_inputs.append(inp)
493-
494-
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
495-
infini_inputs = []
496-
torch_input_clones = []
497-
498-
for inp in inputs:
499-
if isinstance(inp, torch.Tensor):
500-
cloned_inp = inp.clone().detach()
501-
torch_input_clones.append(cloned_inp)
502-
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
503-
infini_inputs.append(infini_tensor)
504-
else:
505-
infini_inputs.append(inp)
506-
507-
infini_kwargs = {}
508-
for key, value in kwargs.items():
509-
if isinstance(value, torch.Tensor):
510-
# Clone tensor and convert to infinicore
511-
cloned_value = value.clone().detach()
512-
torch_input_clones.append(cloned_value)
513-
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
514-
else:
515-
# Pass through non-tensor values (scalars, strings, etc.)
516-
infini_kwargs[key] = value
544+
inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)
517545

518546
# Determine comparison target
519547
comparison_target = test_case.comparison_target
520548

521-
# Handle infinicore output
522-
infini_kwargs = kwargs.copy()
523-
if "out" in infini_kwargs:
524-
out_value = infini_kwargs["out"]
525-
if isinstance(out_value, torch.Tensor):
526-
# Single tensor output
527-
if isinstance(comparison_target, int):
528-
infini_kwargs["out"] = infini_inputs[comparison_target]
529-
else:
530-
cloned_out = out_value.clone().detach()
531-
torch_input_clones.append(cloned_out)
532-
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
533-
elif isinstance(out_value, (tuple, list)):
534-
# Multiple tensor outputs
535-
infini_outputs = []
536-
for tensor in out_value:
537-
cloned_tensor = tensor.clone().detach()
538-
torch_input_clones.append(cloned_tensor)
539-
infini_outputs.append(infinicore_tensor_from_torch(cloned_tensor))
540-
infini_kwargs["out"] = tuple(infini_outputs)
549+
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
550+
infini_inputs, infini_kwargs, cloned_tensors = (
551+
self.prepare_infinicore_inputs_and_kwargs(inputs, kwargs, comparison_target)
552+
)
541553

542554
# Check operator implementations
543555
torch_implemented = True
@@ -698,7 +710,7 @@ def run_test(self, device, test_case, config):
698710

699711
is_valid = compare_fn(infini_comparison, torch_comparison)
700712
if not is_valid:
701-
raise AssertionError(f"Result comparison failed for {test_case}")
713+
raise AssertionError(f"Result comparison failed.")
702714

703715
# ==========================================================================
704716
# UNIFIED BENCHMARKING LOGIC

0 commit comments

Comments
 (0)