10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -2,6 +2,7 @@
    abs,
    add,
    addmm,
    all,
    bitwise_and,
    bitwise_not,
    bitwise_or,
@@ -35,7 +36,11 @@
    sin,
    softmax,
    sub,
    sum,
    tanh,
    topk,
    var,
    var_mean,
)

__all__ = [
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"sum",
"topk",
"var",
"var_mean",
"all",
]
22 changes: 22 additions & 0 deletions src/ntops/kernels/all.py
@@ -0,0 +1,22 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def application(input, output):
    # Map nonzero elements to True and reduce with min: the minimum of the
    # boolean block is True only if every element in the block is nonzero.
    val_block = input[0]
    bool_block = val_block != 0
    res = ntl.min(bool_block, axis=0)
    output[0] = res


def premake(ndim, dim, block_size=None):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)
    tensors = (
        Tensor(ndim, other=1),  # Pad with 1 so padded elements read as True.
        Tensor(ndim, dtype="int8"),
    )
    return arrangement_, application, tensors
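
For reference, the identity the kernel relies on — reducing the "is nonzero" mask with min yields True only when every element is nonzero — can be checked with a few lines of plain PyTorch (a sketch for illustration only, not part of this diff):

import torch

x = torch.tensor([[1, 2, 0], [3, 4, 5]])

mask = x != 0                                            # nonzero -> True, zero -> False
via_min = mask.to(torch.int8).min(dim=1).values.bool()   # min over the mask along dim 1
assert torch.equal(via_min, torch.all(x, dim=1))         # matches torch.all along dim 1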
48 changes: 48 additions & 0 deletions src/ntops/kernels/sum.py
@@ -0,0 +1,48 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def application(input, output):
    # Accumulate one partial sum per block along the reduced dimension.
    accumulator = 0.0

    for i in range(input.shape[0]):
        block_sum = ntl.sum(input[i], axis=0)
        accumulator += block_sum

    output[0] = ntl.cast(accumulator, output.dtype.dtype)


def premake(ndim, dim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),
        Tensor(ndim, dtype=dtype),
    )

    return arrangement_, application, tensors


def arrangement_all_elements(input, output, block_size=None):
    # Flatten the input and tile it into one-dimensional blocks; the output
    # is a single-element tile that receives the reduced value.
    input = input.flatten().tile((block_size,))
    output = output.tile((1,))
    return input, output


def application_all_elements(input, output):
    output[0] = ntl.sum(input, 0)


def premake_all_elements(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement_all_elements, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),
        Tensor(1, dtype=dtype),
    )

    return arrangement_, application_all_elements, tensors
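
The per-dimension variant accumulates one partial sum per block along the reduced dimension, while the all-elements variant flattens the input and reduces it in one block. The block-wise accumulation is equivalent to a direct sum, as the following plain PyTorch sketch illustrates (padding, handled by the arrangement in the real kernel, is assumed to be zero here):

import torch

x = torch.randn(4, 1024)
block_size = 256

acc = torch.zeros(4)
for start in range(0, x.shape[1], block_size):
    acc += x[:, start:start + block_size].sum(dim=1)  # partial sum of one block

assert torch.allclose(acc, x.sum(dim=1))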
62 changes: 62 additions & 0 deletions src/ntops/kernels/topk.py
@@ -0,0 +1,62 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def application(input, values, indices, k, largest):
    val_block = input[0]

    idx_block = ntl.arange(0, val_block.shape[0])

    res_vals = ntl.zeros(val_block.shape, dtype=val_block.dtype)
    res_idxs = ntl.zeros(val_block.shape, dtype=indices.dtype.dtype)
    output_range = ntl.arange(0, val_block.shape[0])

    # Negate the values when the smallest elements are requested, so the
    # selection loop below can always look for the maximum.
    if largest:
        working_val = val_block
    else:
        working_val = -val_block

    sentinel = float("-inf")

    for i in range(k):
        current_max_val = ntl.max(working_val, axis=0)
        current_max_idx = ntl.argmax(working_val, axis=0)

        real_val = -current_max_val if not largest else current_max_val
        real_val = ntl.cast(real_val, res_vals.dtype)

        # Write the selected value and index into output slot i.
        target_mask = output_range == i
        res_vals = ntl.where(target_mask, real_val, res_vals)
        res_idxs = ntl.where(target_mask, current_max_idx, res_idxs)

        # Mask out the selected element so it is not picked again.
        mask_selected = idx_block == current_max_idx
        updated_working_val = ntl.where(mask_selected, sentinel, working_val)
        working_val = ntl.cast(updated_working_val, working_val.dtype)

    values[0] = res_vals
    indices[0] = res_idxs


def premake(
    ndim, dim, k, largest, sorted=True, dtype=None, indices_dtype=None, block_size=None
):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    # Pad with -inf (or +inf when the smallest elements are requested) so
    # that padded positions are never selected.
    pad_val = float("-inf") if largest else float("inf")

    tensors = (
        Tensor(ndim, dtype=dtype, other=pad_val),
        Tensor(ndim, dtype=dtype),
        Tensor(ndim, dtype=indices_dtype),
        Tensor(0, constexpr=True, value=k),
        Tensor(0, constexpr=True, value=largest),
    )

    return arrangement_, application, tensors


premake_all_elements = premake
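
The kernel selects the top k elements iteratively: find the current maximum, write its value and index into output slot i, then overwrite the selected position with -inf so it cannot be picked again; the smallest-k case is reduced to the same loop by negating the input. A plain PyTorch sketch of that selection strategy, for illustration only (the helper name is made up and only float inputs are handled):

import torch

def topk_by_masking(x, k, largest=True):
    working = x.clone() if largest else -x.clone()
    values = torch.empty(k, dtype=x.dtype)
    indices = torch.empty(k, dtype=torch.long)
    for i in range(k):
        idx = torch.argmax(working)      # position of the current maximum
        values[i] = x[idx]
        indices[i] = idx
        working[idx] = float("-inf")     # mask it out for later iterations
    return values, indices

x = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0])
vals, idxs = topk_by_masking(x, k=2)
ref_vals, ref_idxs = torch.topk(x, k=2)
assert torch.equal(vals, ref_vals) and torch.equal(idxs, ref_idxs)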
41 changes: 41 additions & 0 deletions src/ntops/kernels/var.py
@@ -0,0 +1,41 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def application(input, output_var, num_elements, correction):
    # First pass: accumulate the sum in float32 to compute the mean.
    acc_sum = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
    for i in range(input.shape[0]):
        acc_sum += ntl.cast(input[i], ntl.float32)

    n_float = ntl.cast(num_elements, ntl.float32)
    mean = ntl.sum(acc_sum, 0) / n_float

    # Second pass: accumulate squared deviations, masking out padded elements.
    acc_sq_diff = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
    for i in range(input.shape[0]):
        val_f32 = ntl.cast(input[i], ntl.float32)
        diff = val_f32 - mean
        mask = input[i].offsets(-1) < num_elements
        diff = ntl.where(mask, diff, 0.0)
        acc_sq_diff += diff * diff

    sum_sq_diff = ntl.sum(acc_sq_diff, 0)

    divisor = ntl.cast(num_elements - correction, ntl.float32)
    var = sum_sq_diff / divisor

    output_var[0] = ntl.cast(var, output_var.dtype.dtype)


def premake(ndim, dim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)
    tensors = (
        Tensor(ndim, other=0, dtype=dtype),  # Padded positions read as zero.
        Tensor(ndim, dtype=dtype),
        Tensor(0, dtype=dtype),
        Tensor(0, dtype=dtype),
    )
    return arrangement_, application, tensors
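
The kernel makes two passes over the blocks: the first accumulates the sum to obtain the mean, the second accumulates squared deviations (with padded elements masked out) and divides by num_elements - correction. In plain PyTorch the same computation reads as follows (a reference sketch, not part of this diff):

import torch

x = torch.randn(8, 100)
correction = 1  # Bessel's correction, the default for torch.var

mean = x.mean(dim=1, keepdim=True)
var = ((x - mean) ** 2).sum(dim=1) / (x.shape[1] - correction)

assert torch.allclose(var, torch.var(x, dim=1, correction=correction))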
43 changes: 43 additions & 0 deletions src/ntops/kernels/var_mean.py
@@ -0,0 +1,43 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def application(input, output_var, output_mean, num_elements, correction):
    # First pass: accumulate the sum in float32 to compute the mean.
    acc_sum = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
    for i in range(input.shape[0]):
        acc_sum += ntl.cast(input[i], ntl.float32)

    n_float = ntl.cast(num_elements, ntl.float32)
    mean = ntl.sum(acc_sum, 0) / n_float

    # Second pass: accumulate squared deviations, masking out padded elements.
    acc_sq_diff = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
    for i in range(input.shape[0]):
        val = ntl.cast(input[i], ntl.float32)
        diff = val - mean
        mask = input[i].offsets(-1) < num_elements
        diff = ntl.where(mask, diff, 0)
        acc_sq_diff += diff * diff

    sum_sq_diff = ntl.sum(acc_sq_diff, 0)

    divisor = ntl.cast(num_elements - correction, ntl.float32)
    var = sum_sq_diff / divisor

    output_var[0] = ntl.cast(var, output_var.dtype.dtype)
    output_mean[0] = ntl.cast(mean, output_mean.dtype.dtype)


def premake(ndim, dim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)
    tensors = (
        Tensor(ndim, other=0, dtype=dtype),  # Padded positions read as zero.
        Tensor(ndim, dtype=dtype),
        Tensor(ndim, dtype=dtype),
        Tensor(0, dtype=dtype),
        Tensor(0, dtype=dtype),
    )
    return arrangement_, application, tensors
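
This is the fused counterpart of the var kernel: it performs the same two-pass computation but writes both the variance and the mean from a single kernel launch, mirroring torch.var_mean. For reference (plain PyTorch, not part of this diff):

import torch

x = torch.randn(8, 100)
var, mean = torch.var_mean(x, dim=1, correction=1)  # returns (variance, mean)

assert torch.allclose(var, torch.var(x, dim=1, correction=1))
assert torch.allclose(mean, x.mean(dim=1))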
10 changes: 10 additions & 0 deletions src/ntops/torch/__init__.py
@@ -1,6 +1,7 @@
from ntops.torch.abs import abs
from ntops.torch.add import add
from ntops.torch.addmm import addmm
from ntops.torch.all import all
from ntops.torch.bitwise_and import bitwise_and
from ntops.torch.bitwise_not import bitwise_not
from ntops.torch.bitwise_or import bitwise_or
@@ -35,7 +36,11 @@
from ntops.torch.sin import sin
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.sum import sum
from ntops.torch.tanh import tanh
from ntops.torch.topk import topk
from ntops.torch.var import var
from ntops.torch.var_mean import var_mean

__all__ = [
"abs",
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"sum",
"topk",
"var",
"var_mean",
"all",
]
91 changes: 91 additions & 0 deletions src/ntops/torch/all.py
@@ -0,0 +1,91 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def next_power_of_2(n):
    if n == 0:
        return 1
    return 1 << (n - 1).bit_length()


def get_optimal_block_size(dim_size):
    # Round the reduction size up to a power of two, clamped to [32, 1024].
    target_size = next_power_of_2(dim_size)
    if target_size > 1024:
        target_size = 1024
    if target_size < 32:
        target_size = 32
    return target_size


def all(
    input, dim: int | tuple[int, ...] | list[int] | None = None, keepdim=False, *, out=None
):
    output_dtype = torch.bool

    if dim is None:
        dims = tuple(range(input.ndim))
    elif isinstance(dim, int):
        dims = (dim,)
    else:
        dims = tuple(dim)

    if len(dims) == 0:
        if out is not None:
            out.copy_(input)
            return out
        return input.clone()

    # Multiple dims are reduced one at a time, from the largest index down.
    if len(dims) > 1:
        res = input
        sorted_dims = sorted(dims, reverse=True)
        for d in sorted_dims:
            res = all(res, dim=d, keepdim=True)

        if dim is None:
            res = res.view(())
        elif not keepdim:
            for d in sorted_dims:
                res = res.squeeze(d)

        if out is not None:
            out.copy_(res)
            return out
        return res

    target_dim = dims[0] % input.ndim

    if keepdim:
        output_shape = list(input.shape)
        output_shape[target_dim] = 1
    else:
        output_shape = list(input.shape)
        output_shape.pop(target_dim)

    if out is not None:
        values = out
    else:
        values = torch.empty(output_shape, dtype=output_dtype, device=input.device)

    # The kernel expects an output with the same ndim as the input, so hand it
    # a view with the reduced dimension kept as size 1.
    values_keepdim_shape = list(input.shape)
    values_keepdim_shape[target_dim] = 1
    values_for_kernel = values.view(values_keepdim_shape)

    kernel_ndim = input.ndim
    reduction_size = input.shape[target_dim]
    block_size = get_optimal_block_size(reduction_size)

    kernel = _cached_make(
        ntops.kernels.all.premake, kernel_ndim, target_dim, block_size
    )

    kernel(input, values_for_kernel)

    if dim is None:
        result = values.view(())
        if out is not None and values.data_ptr() != out.data_ptr():
            out.copy_(result)
        return result

    return values
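
A minimal usage sketch for the new wrapper, assuming a CUDA device is available (the underlying ninetoothed kernels run on the GPU); exact device requirements may differ:

import torch

import ntops.torch

x = torch.randint(0, 2, (128, 256), device="cuda", dtype=torch.float32)

out = ntops.torch.all(x, dim=1)
assert torch.equal(out, torch.all(x, dim=1))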