Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/infinicore/ops/random_sample/random_sample_infiniop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ static void calculate(

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(indices->device()), &desc,
indices->desc(), logits->desc()));
cache.put(seed, desc);
} else {
Expand Down
134 changes: 73 additions & 61 deletions test/infinicore/framework/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,7 @@ def run_tests(self, devices, test_func, test_type="Test"):
return False

except Exception as e:
error_msg = (
f"{test_case} - {InfiniDeviceNames[device]} - Error: {e}"
)
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)

Expand Down Expand Up @@ -392,7 +390,7 @@ def _create_tensor_from_spec(self, spec, device):
return spec.create_torch_tensor(device)
return spec

def prepare_inputs_and_kwargs(self, test_case, device):
def prepare_pytorch_inputs_and_kwargs(self, test_case, device):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs
"""
Expand Down Expand Up @@ -455,6 +453,71 @@ def prepare_inputs_and_kwargs(self, test_case, device):

return inputs, kwargs

def prepare_infinicore_list(self, input_sequence, clone=False):
cloned_tensors = []
infini_list = []
for item in input_sequence:
if isinstance(item, torch.Tensor):
if clone:
cloned_item = item.clone().detach()
infini_item = infinicore_tensor_from_torch(cloned_item)
cloned_tensors.append(cloned_item)
else:
infini_item = infinicore_tensor_from_torch(item)
else:
infini_item = item
infini_list.append(infini_item)
return infini_list, cloned_tensors

def prepare_infinicore_inputs_and_kwargs(self, inputs, kwargs, comparison_target):
cloned_tensors = []
infini_inputs = []

# Prepare infinicore inputs - only clone if needed for comparison
for i, inp in enumerate(inputs):
if isinstance(inp, torch.Tensor):
# Clone only if this input will be used for comparison
if comparison_target == i:
cloned_inp = inp.clone().detach()
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
cloned_tensors.append(cloned_inp)
else:
# For non-comparison inputs, we can use the original (but still need to convert)
infini_tensor = infinicore_tensor_from_torch(inp)
infini_inputs.append(infini_tensor)
elif isinstance(inp, (tuple, list)):
infini_list, cloned_list = self.prepare_infinicore_list(
inp, comparison_target == i
)
infini_inputs.append(infini_list)
cloned_tensors.append(cloned_list)
else:
infini_inputs.append(inp)

# Prepare infinicore kwargs
infini_kwargs = {}
for key, value in kwargs.items():
if isinstance(value, torch.Tensor):
# Check if this tensor is used for output comparison
if key == "out" and comparison_target == "out":
cloned_value = value.clone().detach()
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
cloned_tensors.append(cloned_value)
elif key == "out" and isinstance(comparison_target, int):
infini_kwargs[key] = infini_inputs[comparison_target]
else:
infini_kwargs[key] = infinicore_tensor_from_torch(value)
elif isinstance(value, (tuple, list)):
infini_list, cloned_list = self.prepare_infinicore_list(
value, key == "out"
)
cloned_tensors.append(cloned_list)
infini_kwargs[key] = infini_list
else:
infini_kwargs[key] = value

return infini_inputs, infini_kwargs, cloned_tensors

def run_test(self, device, test_case, config):
"""
Unified test execution flow
Expand All @@ -478,66 +541,15 @@ def run_test(self, device, test_case, config):
)

# Prepare inputs and kwargs with actual tensors
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)

# For in-place operations on input tensors, we need to preserve the original state
original_inputs = []
if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
# This is an in-place operation on an input tensor
# Store original values for comparison
for inp in inputs:
if isinstance(inp, torch.Tensor):
original_inputs.append(inp.clone().detach())
else:
original_inputs.append(inp)

# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_inputs = []
torch_input_clones = []

for inp in inputs:
if isinstance(inp, torch.Tensor):
cloned_inp = inp.clone().detach()
torch_input_clones.append(cloned_inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
infini_inputs.append(infini_tensor)
else:
infini_inputs.append(inp)

infini_kwargs = {}
for key, value in kwargs.items():
if isinstance(value, torch.Tensor):
# Clone tensor and convert to infinicore
cloned_value = value.clone().detach()
torch_input_clones.append(cloned_value)
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
else:
# Pass through non-tensor values (scalars, strings, etc.)
infini_kwargs[key] = value
inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)

# Determine comparison target
comparison_target = test_case.comparison_target

# Handle infinicore output
infini_kwargs = kwargs.copy()
if "out" in infini_kwargs:
out_value = infini_kwargs["out"]
if isinstance(out_value, torch.Tensor):
# Single tensor output
if isinstance(comparison_target, int):
infini_kwargs["out"] = infini_inputs[comparison_target]
else:
cloned_out = out_value.clone().detach()
torch_input_clones.append(cloned_out)
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
elif isinstance(out_value, (tuple, list)):
# Multiple tensor outputs
infini_outputs = []
for tensor in out_value:
cloned_tensor = tensor.clone().detach()
torch_input_clones.append(cloned_tensor)
infini_outputs.append(infinicore_tensor_from_torch(cloned_tensor))
infini_kwargs["out"] = tuple(infini_outputs)
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_inputs, infini_kwargs, cloned_tensors = (
self.prepare_infinicore_inputs_and_kwargs(inputs, kwargs, comparison_target)
)

# Check operator implementations
torch_implemented = True
Expand Down Expand Up @@ -698,7 +710,7 @@ def run_test(self, device, test_case, config):

is_valid = compare_fn(infini_comparison, torch_comparison)
if not is_valid:
raise AssertionError(f"Result comparison failed for {test_case}")
raise AssertionError(f"Result comparison failed.")

# ==========================================================================
# UNIFIED BENCHMARKING LOGIC
Expand Down