diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 084e52c..c586c25 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -2,6 +2,7 @@ abs, add, addmm, + all, bitwise_and, bitwise_not, bitwise_or, @@ -35,7 +36,11 @@ sin, softmax, sub, + sum, tanh, + topk, + var, + var_mean, ) __all__ = [ @@ -76,4 +81,9 @@ "softmax", "sub", "tanh", + "sum", + "topk", + "var", + "var_mean", + "all", ] diff --git a/src/ntops/kernels/all.py b/src/ntops/kernels/all.py new file mode 100644 index 0000000..c3b7616 --- /dev/null +++ b/src/ntops/kernels/all.py @@ -0,0 +1,22 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, output): + val_block = input[0] + bool_block = val_block != 0 + res = ntl.min(bool_block, axis=0) + output[0] = res + + +def premake(ndim, dim, block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + tensors = ( + Tensor(ndim, other=1), + Tensor(ndim, dtype="int8"), + ) + return arrangement_, application, tensors diff --git a/src/ntops/kernels/sum.py b/src/ntops/kernels/sum.py new file mode 100644 index 0000000..467df9a --- /dev/null +++ b/src/ntops/kernels/sum.py @@ -0,0 +1,48 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, output): + accumulator = 0.0 + + for i in range(input.shape[0]): + block_sum = ntl.sum(input[i], axis=0) + accumulator += block_sum + + output[0] = ntl.cast(accumulator, output.dtype.dtype) + + +def premake(ndim, dim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors + + +def arrangement_all_elements(input, output, block_size=None): + input = input.flatten().tile((block_size,)) + output = output.tile((1,)) + return input, output + + +def application_all_elements(input, output): + output[0] = ntl.sum(input, 0) + + +def premake_all_elements(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement_all_elements, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(1, dtype=dtype), + ) + + return arrangement_, application_all_elements, tensors diff --git a/src/ntops/kernels/topk.py b/src/ntops/kernels/topk.py new file mode 100644 index 0000000..7a642d9 --- /dev/null +++ b/src/ntops/kernels/topk.py @@ -0,0 +1,62 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, values, indices, k, largest): + val_block = input[0] + + idx_block = ntl.arange(0, val_block.shape[0]) + + res_vals = ntl.zeros(val_block.shape, dtype=val_block.dtype) + res_idxs = ntl.zeros(val_block.shape, dtype=indices.dtype.dtype) + output_range = ntl.arange(0, val_block.shape[0]) + + if largest: + working_val = val_block + else: + working_val = -val_block + + sentinel = float("-inf") + + for i in range(k): + current_max_val = ntl.max(working_val, axis=0) + current_max_idx = ntl.argmax(working_val, axis=0) + + real_val = -current_max_val if not largest else current_max_val + real_val = ntl.cast(real_val, res_vals.dtype) + + target_mask = output_range == i + res_vals = ntl.where(target_mask, real_val, res_vals) + res_idxs = 
ntl.where(target_mask, current_max_idx, res_idxs) + + mask_selected = idx_block == current_max_idx + updated_working_val = ntl.where(mask_selected, sentinel, working_val) + working_val = ntl.cast(updated_working_val, working_val.dtype) + + values[0] = res_vals + indices[0] = res_idxs + + +def premake( + ndim, dim, k, largest, sorted=True, dtype=None, indices_dtype=None, block_size=None +): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + pad_val = float("-inf") if largest else float("inf") + + tensors = ( + Tensor(ndim, dtype=dtype, other=pad_val), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=indices_dtype), + Tensor(0, constexpr=True, value=k), + Tensor(0, constexpr=True, value=largest), + ) + + return arrangement_, application, tensors + + +premake_all_elements = premake diff --git a/src/ntops/kernels/var.py b/src/ntops/kernels/var.py new file mode 100644 index 0000000..d7e5fe3 --- /dev/null +++ b/src/ntops/kernels/var.py @@ -0,0 +1,41 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, output_var, num_elements, correction): + acc_sum = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + for i in range(input.shape[0]): + acc_sum += ntl.cast(input[i], ntl.float32) + + n_float = ntl.cast(num_elements, ntl.float32) + mean = ntl.sum(acc_sum, 0) / n_float + + acc_sq_diff = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + for i in range(input.shape[0]): + val_f32 = ntl.cast(input[i], ntl.float32) + diff = val_f32 - mean + mask = input[i].offsets(-1) < num_elements + diff = ntl.where(mask, diff, 0.0) + acc_sq_diff += diff * diff + + sum_sq_diff = ntl.sum(acc_sq_diff, 0) + + divisor = ntl.cast(num_elements - correction, ntl.float32) + var = sum_sq_diff / divisor + + output_var[0] = ntl.cast(var, output_var.dtype.dtype) + + +def premake(ndim, dim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + tensors = ( + Tensor(ndim, other=0, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=dtype), + Tensor(0, dtype=dtype), + ) + return arrangement_, application, tensors diff --git a/src/ntops/kernels/var_mean.py b/src/ntops/kernels/var_mean.py new file mode 100644 index 0000000..99df430 --- /dev/null +++ b/src/ntops/kernels/var_mean.py @@ -0,0 +1,43 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, output_var, output_mean, num_elements, correction): + acc_sum = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + for i in range(input.shape[0]): + acc_sum += ntl.cast(input[i], ntl.float32) + + n_float = ntl.cast(num_elements, ntl.float32) + mean = ntl.sum(acc_sum, 0) / n_float + + acc_sq_diff = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + for i in range(input.shape[0]): + val = ntl.cast(input[i], ntl.float32) + diff = val - mean + mask = input[i].offsets(-1) < num_elements + diff = ntl.where(mask, diff, 0) + acc_sq_diff += diff * diff + + sum_sq_diff = ntl.sum(acc_sq_diff, 0) + + divisor = ntl.cast(num_elements - correction, ntl.float32) + var = sum_sq_diff / divisor + + output_var[0] = ntl.cast(var, output_var.dtype.dtype) + output_mean[0] = ntl.cast(mean, output_mean.dtype.dtype) + + +def premake(ndim, dim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + tensors = ( + Tensor(ndim, other=0, 
dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=dtype), + Tensor(0, dtype=dtype), + ) + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 702877e..8114a44 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -1,6 +1,7 @@ from ntops.torch.abs import abs from ntops.torch.add import add from ntops.torch.addmm import addmm +from ntops.torch.all import all from ntops.torch.bitwise_and import bitwise_and from ntops.torch.bitwise_not import bitwise_not from ntops.torch.bitwise_or import bitwise_or @@ -35,7 +36,11 @@ from ntops.torch.sin import sin from ntops.torch.softmax import softmax from ntops.torch.sub import sub +from ntops.torch.sum import sum from ntops.torch.tanh import tanh +from ntops.torch.topk import topk +from ntops.torch.var import var +from ntops.torch.var_mean import var_mean __all__ = [ "abs", @@ -76,4 +81,9 @@ "softmax", "sub", "tanh", + "sum", + "topk", + "var", + "var_mean", + "all", ] diff --git a/src/ntops/torch/all.py b/src/ntops/torch/all.py new file mode 100644 index 0000000..c6cc47b --- /dev/null +++ b/src/ntops/torch/all.py @@ -0,0 +1,91 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def next_power_of_2(n): + if n == 0: + return 1 + return 1 << (n - 1).bit_length() + + +def get_optimal_block_size(dim_size): + target_size = next_power_of_2(dim_size) + if target_size > 1024: + target_size = 1024 + if target_size < 32: + target_size = 32 + return target_size + + +def all( + input, dim: int | tuple[int] | list[int] | None = None, keepdim=False, *, out=None +): + output_dtype = torch.bool + + if dim is None: + dims = tuple(range(input.ndim)) + elif isinstance(dim, int): + dims = (dim,) + else: + dims = tuple(dim) + + if len(dims) == 0: + if out is not None: + out.copy_(input) + return out + return input.clone() + + if len(dims) > 1: + res = input + sorted_dims = sorted(dims, reverse=True) + for d in sorted_dims: + res = all(res, dim=d, keepdim=True) + + if dim is None: + res = res.view(()) + elif not keepdim: + for d in sorted_dims: + res = res.squeeze(d) + + if out is not None: + out.copy_(res) + return out + return res + + target_dim = dims[0] % input.ndim + + if keepdim: + output_shape = list(input.shape) + output_shape[target_dim] = 1 + else: + output_shape = list(input.shape) + output_shape.pop(target_dim) + + if out is not None: + values = out + else: + values = torch.empty(output_shape, dtype=output_dtype, device=input.device) + + values_keepdim_shape = list(input.shape) + values_keepdim_shape[target_dim] = 1 + values_for_kernel = values.view(values_keepdim_shape) + + kernel_ndim = input.ndim + reduction_size = input.shape[target_dim] + block_size = get_optimal_block_size(reduction_size) + + kernel = _cached_make( + ntops.kernels.all.premake, kernel_ndim, target_dim, block_size + ) + + kernel(input, values_for_kernel) + + if dim is None: + result = values.view(()) + if out is not None and values.data_ptr() != out.data_ptr(): + out.copy_(result) + return result + + return values diff --git a/src/ntops/torch/sum.py b/src/ntops/torch/sum.py new file mode 100644 index 0000000..4d82f4f --- /dev/null +++ b/src/ntops/torch/sum.py @@ -0,0 +1,100 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def next_power_of_2(n): + if n == 0: + return 1 + return 1 << (n - 1).bit_length() + + +def get_optimal_block_size(dim_size): + target_size = next_power_of_2(dim_size) + 
+ if target_size > 1024: + target_size = 1024 + + if target_size < 32: + target_size = 32 + + return target_size + + +def sum( + input, + dim: int | tuple[int] | list[int] | None = None, + keepdim=False, + *, + dtype=None, + out=None, +): + if dtype is None: + dtype = input.dtype + + if dim is None: + current = input + block_size = get_optimal_block_size(current.numel()) + + while current.numel() > 1: + output_len = math.ceil(current.numel() / block_size) + output = torch.empty((output_len,), dtype=dtype, device=current.device) + + kernel = _cached_make( + ntops.kernels.sum.premake_all_elements, + current.ndim, + current.dtype, + block_size, + ) + kernel(current, output) + current = output + + result = current.view(()) + + if out is not None: + out.copy_(result) + return out + + return result + else: + if isinstance(dim, int): + dims = (dim,) + else: + dims = tuple(dim) + + output_shape = list(input.shape) + for d in dims: + if d < 0: + d += input.ndim + output_shape[d] = 1 + + temp_out = torch.empty(output_shape, dtype=dtype, device=input.device) + block_size = get_optimal_block_size(output_shape[dims[0]]) + + kernel = _cached_make( + ntops.kernels.sum.premake, input.ndim, dims, dtype, block_size + ) + kernel(input, temp_out) + + if not keepdim: + dims_to_remove = sorted( + [d if d >= 0 else d + input.ndim for d in dims], reverse=True + ) + + final_shape = list(output_shape) + for d in dims_to_remove: + del final_shape[d] + + if not final_shape: + temp_out = temp_out.view(()) + else: + temp_out = temp_out.view(final_shape) + + if out is not None: + out.copy_(temp_out) + return out + + return temp_out diff --git a/src/ntops/torch/topk.py b/src/ntops/torch/topk.py new file mode 100644 index 0000000..4063cef --- /dev/null +++ b/src/ntops/torch/topk.py @@ -0,0 +1,82 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def next_power_of_2(n): + if n == 0: + return 1 + return 1 << (n - 1).bit_length() + + +def get_optimal_block_size(dim_size): + target_size = next_power_of_2(dim_size) + + if target_size > 1024: + target_size = 1024 + + if target_size < 32: + target_size = 32 + + return target_size + + +def topk(input, k, dim=None, largest=True, sorted=True, *, out=None): + dtype = input.dtype + indices_dtype = torch.int64 + + if dim is None: + input_logic = input.contiguous().flatten() + target_dim = 0 + original_output_shape = (k,) + else: + input_logic = input + if dim < 0: + dim += input.ndim + target_dim = dim + original_output_shape = list(input.shape) + original_output_shape[dim] = k + + dim_size = input_logic.shape[target_dim] + block_size = get_optimal_block_size(dim_size) + + if out is not None: + values, indices = out + if dim is None: + values_logic = values.view(-1) + indices_logic = indices.view(-1) + else: + values_logic = values + indices_logic = indices + else: + logic_output_shape = list(input_logic.shape) + logic_output_shape[target_dim] = k + values_logic = torch.empty(logic_output_shape, dtype=dtype, device=input.device) + indices_logic = torch.empty( + logic_output_shape, dtype=indices_dtype, device=input.device + ) + + kernel = _cached_make( + ntops.kernels.topk.premake, + input_logic.ndim, + target_dim, + k, + largest, + sorted, + dtype, + indices_dtype, + block_size, + ) + + kernel(input_logic, values_logic, indices_logic, k, largest) + + if out is None: + if dim is None: + return values_logic.view(original_output_shape), indices_logic.view( + original_output_shape + ) + else: + return values_logic, indices_logic + + return values, indices 
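For reference, here is a minimal usage sketch of the torch-level wrappers added up to this point (ntops.torch.sum, ntops.torch.topk, ntops.torch.all). It is not part of the diff itself; the device, shapes, and tolerances are illustrative only, and a CUDA device is assumed since the tests below exercise the kernels on CUDA.

import torch

import ntops

# Illustrative input; the new wrappers mirror the corresponding torch reductions.
x = torch.randn(4, 1000, device="cuda")

s = ntops.torch.sum(x, dim=1, keepdim=True)           # shape (4, 1)
v, i = ntops.torch.topk(x, k=5, dim=1, largest=True)  # shapes (4, 5)
b = ntops.torch.all(x != 0, dim=1)                    # shape (4,), dtype bool

assert torch.allclose(s, torch.sum(x, dim=1, keepdim=True), rtol=1e-3, atol=1e-3)
assert torch.allclose(v, torch.topk(x, 5, dim=1).values, rtol=1e-3, atol=1e-3)
assert torch.equal(b, torch.all(x != 0, dim=1))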
diff --git a/src/ntops/torch/var.py b/src/ntops/torch/var.py new file mode 100644 index 0000000..0f2037d --- /dev/null +++ b/src/ntops/torch/var.py @@ -0,0 +1,78 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def next_power_of_2(n): + if n == 0: + return 1 + return 1 << (n - 1).bit_length() + + +def get_optimal_block_size(dim_size): + target_size = next_power_of_2(dim_size) + if target_size > 1024: + target_size = 1024 + if target_size < 32: + target_size = 32 + return target_size + + +def var(input, dim=None, correction=1, keepdim=False, *, dtype=None, out=None): + if dtype is not None and input.dtype != dtype: + input = input.to(dtype) + + ndim = input.ndim + if dim is None: + target_dims = tuple(range(ndim)) + elif isinstance(dim, int): + target_dims = (dim,) + else: + target_dims = tuple(dim) + + target_dims = tuple(d if d >= 0 else d + ndim for d in target_dims) + + non_target_dims = [i for i in range(ndim) if i not in target_dims] + permuted_order = non_target_dims + list(target_dims) + + input_permuted = input.permute(permuted_order).contiguous() + + num_non_target = len(non_target_dims) + new_target_dims = tuple(range(num_non_target, ndim)) + + num_elements = 1 + for d in new_target_dims: + num_elements *= input_permuted.shape[d] + + kernel_out_shape = list(input_permuted.shape) + for d in new_target_dims: + kernel_out_shape[d] = 1 + + temp_var = torch.empty(kernel_out_shape, dtype=input.dtype, device=input.device) + block_size = get_optimal_block_size(num_elements) + + kernel = _cached_make( + ntops.kernels.var.premake, + input_permuted.ndim, + new_target_dims, + input_permuted.dtype, + block_size, + ) + + kernel(input_permuted, temp_var, num_elements, correction) + + if keepdim: + final_shape = list(input.shape) + for d in target_dims: + final_shape[d] = 1 + else: + final_shape = [input.shape[i] for i in non_target_dims] + + res_var = temp_var.view(final_shape) if final_shape else temp_var.view([]) + + if out is not None: + out.copy_(res_var) + return out + + return res_var diff --git a/src/ntops/torch/var_mean.py b/src/ntops/torch/var_mean.py new file mode 100644 index 0000000..de21252 --- /dev/null +++ b/src/ntops/torch/var_mean.py @@ -0,0 +1,83 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def next_power_of_2(n): + if n == 0: + return 1 + return 1 << (n - 1).bit_length() + + +def get_optimal_block_size(dim_size): + target_size = next_power_of_2(dim_size) + if target_size > 1024: + target_size = 1024 + if target_size < 32: + target_size = 32 + return target_size + + +def var_mean(input, dim=None, correction=1, keepdim=False, *, dtype=None, out=None): + if dtype is not None and input.dtype != dtype: + input = input.to(dtype) + + ndim = input.ndim + if dim is None: + target_dims = tuple(range(ndim)) + elif isinstance(dim, int): + target_dims = (dim,) + else: + target_dims = tuple(dim) + + target_dims = tuple(d if d >= 0 else d + ndim for d in target_dims) + + non_target_dims = [i for i in range(ndim) if i not in target_dims] + permuted_order = non_target_dims + list(target_dims) + + input_permuted = input.permute(permuted_order).contiguous() + + num_non_target = len(non_target_dims) + new_target_dims = tuple(range(num_non_target, ndim)) + + num_elements = 1 + for d in new_target_dims: + num_elements *= input_permuted.shape[d] + + kernel_out_shape = list(input_permuted.shape) + for d in new_target_dims: + kernel_out_shape[d] = 1 + + temp_var = torch.empty(kernel_out_shape, dtype=input.dtype, 
device=input.device) + temp_mean = torch.empty(kernel_out_shape, dtype=input.dtype, device=input.device) + block_size = get_optimal_block_size(num_elements) + + kernel = _cached_make( + ntops.kernels.var_mean.premake, + input_permuted.ndim, + new_target_dims, + input_permuted.dtype, + block_size, + ) + + kernel(input_permuted, temp_var, temp_mean, num_elements, correction) + + if keepdim: + final_shape = list(input.shape) + for d in target_dims: + final_shape[d] = 1 + else: + result_shape = [input.shape[i] for i in non_target_dims] + final_shape = result_shape + + res_var = temp_var.view(final_shape) if final_shape else temp_var.view([]) + res_mean = temp_mean.view(final_shape) if final_shape else temp_mean.view([]) + + if out is not None: + out_var, out_mean = out + out_var.copy_(res_var) + out_mean.copy_(res_mean) + return out_var, out_mean + + return res_var, res_mean diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 0000000..d964e0d --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,36 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +@pytest.mark.parametrize("keepdim", [False, True]) +def test_all(shape, dtype, device, rtol, atol, keepdim): + if dtype == torch.bool: + input_tensor = torch.randint(0, 2, shape, device=device).bool() + else: + input_tensor = torch.randn(shape, dtype=dtype, device=device) + mask = torch.rand(shape, device=device) < 0.2 + input_tensor[mask] = 0 + + if random.random() < 0.2: + dim = None + else: + dim = random.randint(0, len(shape) - 1) + if random.choice([True, False]): + dim -= len(shape) + + ntops_res = ntops.torch.all(input_tensor, dim=dim, keepdim=keepdim) + + if dim is None: + ref_res = torch.all(input_tensor) + else: + ref_res = torch.all(input_tensor, dim=dim, keepdim=keepdim) + + assert torch.equal(ntops_res, ref_res) diff --git a/tests/test_sum.py b/tests/test_sum.py new file mode 100644 index 0000000..3f667ef --- /dev/null +++ b/tests/test_sum.py @@ -0,0 +1,35 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +@pytest.mark.parametrize("keepdim", (False, True)) +def test_sum_dim(shape, dtype, device, rtol, atol, keepdim): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + dim = random.randint(0, input_tensor.ndim - 1) + + if random.choice((True, False)): + dim = dim - input_tensor.ndim + + ntops_value = ntops.torch.sum(input_tensor, dim=dim, keepdim=keepdim) + reference_value = torch.sum(input_tensor, dim=dim, keepdim=keepdim) + + assert torch.allclose(ntops_value, reference_value, rtol=rtol, atol=atol) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_sum_global(shape, dtype, device, rtol, atol): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + + ntops_value = ntops.torch.sum(input_tensor) + reference_value = torch.sum(input_tensor) + + assert torch.allclose(ntops_value, reference_value, rtol=rtol, atol=atol) diff --git a/tests/test_topk.py b/tests/test_topk.py new file mode 100644 index 0000000..24a5839 --- /dev/null +++ b/tests/test_topk.py @@ -0,0 +1,68 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import 
skip_if_cuda_not_available +from tests.utils import _random_shape + + +def generate_topk_args_dim(): + args = [] + for dtype in (torch.float32, torch.float16): + device = "cuda" + rtol, atol = (1e-3, 1e-3) if dtype == torch.float32 else (1e-2, 1e-2) + + for ndim in range(1, 4): + for _ in range(5): + shape = _random_shape(ndim) + dim = random.randint(0, ndim - 1) + dim_size = shape[dim] + + k = random.randint(1, min(dim_size, 128)) + args.append((shape, k, dim, dtype, device, rtol, atol)) + return "shape, k, dim, dtype, device, rtol, atol", args + + +def generate_topk_args_global(): + args = [] + for dtype in (torch.float32, torch.float16): + device = "cuda" + rtol, atol = (1e-3, 1e-3) if dtype == torch.float32 else (1e-2, 1e-2) + + candidates = [(100,), (10, 20), (5, 5, 5)] + for shape in candidates: + numel = 1 + for s in shape: + numel *= s + k = random.randint(1, min(numel, 64)) + args.append((shape, k, dtype, device, rtol, atol)) + return "shape, k, dtype, device, rtol, atol", args + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_topk_args_dim()) +@pytest.mark.parametrize("largest", [True, False]) +def test_topk_dim(shape, k, dim, dtype, device, rtol, atol, largest): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + + ntops_v, ntops_i = ntops.torch.topk(input_tensor, k, dim=dim, largest=largest) + ref_v, ref_i = torch.topk(input_tensor, k, dim=dim, largest=largest) + + assert torch.allclose(ntops_v, ref_v, rtol=rtol, atol=atol) + gathered = torch.gather(input_tensor, dim, ntops_i) + assert torch.allclose(gathered, ntops_v, rtol=rtol, atol=atol) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_topk_args_global()) +@pytest.mark.parametrize("largest", [True, False]) +def test_topk_global(shape, k, dtype, device, rtol, atol, largest): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + + ntops_v, ntops_i = ntops.torch.topk(input_tensor, k, dim=None, largest=largest) + + ref_v, ref_i = torch.topk(input_tensor.flatten(), k, dim=0, largest=largest) + + assert torch.allclose(ntops_v, ref_v, rtol=rtol, atol=atol) diff --git a/tests/test_var.py b/tests/test_var.py new file mode 100644 index 0000000..976edaa --- /dev/null +++ b/tests/test_var.py @@ -0,0 +1,42 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +@pytest.mark.parametrize("correction", [0, 1]) +@pytest.mark.parametrize("keepdim", [False, True]) +def test_var_dim(shape, dtype, device, rtol, atol, correction, keepdim): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + dim = random.randint(0, input_tensor.ndim - 1) + + if random.choice((True, False)): + dim = dim - input_tensor.ndim + + ntops_value = ntops.torch.var( + input_tensor, dim=dim, correction=correction, keepdim=keepdim + ) + + reference_value = torch.var( + input_tensor, dim=dim, correction=correction, keepdim=keepdim + ) + + assert torch.allclose(ntops_value, reference_value, rtol=rtol, atol=atol) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +@pytest.mark.parametrize("correction", [0, 1]) +def test_var_global(shape, dtype, device, rtol, atol, correction): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + + ntops_value = ntops.torch.var(input_tensor, correction=correction) + reference_value = torch.var(input_tensor, correction=correction) + + 
assert torch.allclose(ntops_value, reference_value, rtol=rtol, atol=atol) diff --git a/tests/test_var_mean.py b/tests/test_var_mean.py new file mode 100644 index 0000000..44a7b4c --- /dev/null +++ b/tests/test_var_mean.py @@ -0,0 +1,41 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +@pytest.mark.parametrize("correction", [0, 1]) +@pytest.mark.parametrize("keepdim", [False, True]) +def test_var_mean_general(keepdim, correction, shape, dtype, device, rtol, atol): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + dim = random.randint(0, input_tensor.ndim - 1) if input_tensor.ndim > 0 else None + if dim is not None and random.choice((True, False)): + dim = dim - input_tensor.ndim + + nt_var, nt_mean = ntops.torch.var_mean( + input_tensor, dim=dim, correction=correction, keepdim=keepdim + ) + ref_var, ref_mean = torch.var_mean( + input_tensor, dim=dim, correction=correction, keepdim=keepdim + ) + + assert torch.allclose(nt_var, ref_var, rtol=rtol, atol=atol, equal_nan=True) + assert torch.allclose(nt_mean, ref_mean, rtol=rtol, atol=atol, equal_nan=True) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_var_mean_global(shape, dtype, device, rtol, atol): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + + nt_var, nt_mean = ntops.torch.var_mean(input_tensor, dim=None) + ref_var, ref_mean = torch.var_mean(input_tensor, dim=None) + + assert torch.allclose(nt_var, ref_var, rtol=rtol, atol=atol, equal_nan=True) + assert torch.allclose(nt_mean, ref_mean, rtol=rtol, atol=atol, equal_nan=True)
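Along the same lines, a short sketch of the variance wrappers added in this diff (ntops.torch.var and ntops.torch.var_mean) next to their torch counterparts. Again, the device, shapes, and tolerance values are illustrative and a CUDA device is assumed.

import torch

import ntops

x = torch.randn(8, 256, device="cuda")

# Unbiased variance along dim=1, and var/mean in one pass with keepdim.
var = ntops.torch.var(x, dim=1, correction=1)
var_k, mean_k = ntops.torch.var_mean(x, dim=1, correction=0, keepdim=True)

ref_var = torch.var(x, dim=1, correction=1)
ref_var_k, ref_mean_k = torch.var_mean(x, dim=1, correction=0, keepdim=True)

assert torch.allclose(var, ref_var, rtol=1e-3, atol=1e-3)
assert torch.allclose(var_k, ref_var_k, rtol=1e-3, atol=1e-3)
assert torch.allclose(mean_k, ref_mean_k, rtol=1e-3, atol=1e-3)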