diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index 084e52c..c24a979 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -36,6 +36,10 @@
     softmax,
     sub,
     tanh,
+    logical_or,
+    logical_xor,
+    logsigmoid,
+    where,
 )
 
 __all__ = [
@@ -76,4 +80,8 @@
     "softmax",
     "sub",
     "tanh",
+    "logical_or",
+    "logical_xor",
+    "logsigmoid",
+    "where",
 ]
diff --git a/src/ntops/kernels/logical_or.py b/src/ntops/kernels/logical_or.py
new file mode 100644
index 0000000..85dc7f0
--- /dev/null
+++ b/src/ntops/kernels/logical_or.py
@@ -0,0 +1,21 @@
+import functools
+
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, other, output):
+    output = (input != 0) | (other != 0)  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/logical_xor.py b/src/ntops/kernels/logical_xor.py
new file mode 100644
index 0000000..60095f3
--- /dev/null
+++ b/src/ntops/kernels/logical_xor.py
@@ -0,0 +1,21 @@
+import functools
+
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, other, output):
+    output = (input != 0) ^ (other != 0)  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/logsigmoid.py b/src/ntops/kernels/logsigmoid.py
new file mode 100644
index 0000000..2239c78
--- /dev/null
+++ b/src/ntops/kernels/logsigmoid.py
@@ -0,0 +1,26 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, output):
+    # logsigmoid(x) = log(sigmoid(x)); compute in float32 for accuracy.
+    x = ntl.cast(input, ntl.float32)
+    y = ntl.log(ntl.sigmoid(x))
+
+    # Cast back to the output dtype, matching the input/torch behavior.
+    output = ntl.cast(y, output.dtype)  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/where.py b/src/ntops/kernels/where.py
new file mode 100644
index 0000000..3583fdc
--- /dev/null
+++ b/src/ntops/kernels/where.py
@@ -0,0 +1,25 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(condition, input, other, output):
+    # Treat nonzero condition values as True, matching torch.where semantics.
+    cond_bool = condition != 0
+    output = ntl.where(cond_bool, input, other)  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),  # condition
+        Tensor(ndim, dtype=dtype),  # input
+        Tensor(ndim, dtype=dtype),  # other
+        Tensor(ndim, dtype=dtype),  # output
+    )
+
+    return arrangement_, application, tensors
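For reference, the element-wise semantics the four new kernels encode, sketched in plain PyTorch; the ref_ helpers are hypothetical and only mirror the application bodies above, not part of the patch:

import torch

def ref_logical_or(input, other):
    # Nonzero counts as True, as in (input != 0) | (other != 0).
    return (input != 0) | (other != 0)

def ref_logical_xor(input, other):
    return (input != 0) ^ (other != 0)

def ref_logsigmoid(input):
    # float32 compute, then cast back to the input dtype.
    x = input.to(torch.float32)
    return torch.log(torch.sigmoid(x)).to(input.dtype)

def ref_where(condition, input, other):
    return torch.where(condition != 0, input, other)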
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 702877e..d11ff8d 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -36,6 +36,10 @@
 from ntops.torch.softmax import softmax
 from ntops.torch.sub import sub
 from ntops.torch.tanh import tanh
+from ntops.torch.logical_or import logical_or
+from ntops.torch.logical_xor import logical_xor
+from ntops.torch.logsigmoid import logsigmoid
+from ntops.torch.where import where
 
 __all__ = [
     "abs",
@@ -76,4 +80,8 @@
     "softmax",
     "sub",
     "tanh",
+    "logical_or",
+    "logical_xor",
+    "logsigmoid",
+    "where",
 ]
diff --git a/src/ntops/torch/logical_or.py b/src/ntops/torch/logical_or.py
new file mode 100644
index 0000000..07b9767
--- /dev/null
+++ b/src/ntops/torch/logical_or.py
@@ -0,0 +1,15 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def logical_or(input, other, *, out=None):
+    if out is None:
+        out = torch.empty_like(input, dtype=torch.bool)
+
+    kernel = _cached_make(ntops.kernels.logical_or.premake, input.ndim)
+
+    kernel(input, other, out)
+
+    return out
diff --git a/src/ntops/torch/logical_xor.py b/src/ntops/torch/logical_xor.py
new file mode 100644
index 0000000..47128b3
--- /dev/null
+++ b/src/ntops/torch/logical_xor.py
@@ -0,0 +1,24 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def logical_xor(input, other, *, out=None):
+    kernel = _cached_make(ntops.kernels.logical_xor.premake, input.ndim)
+
+    if out is None:
+        out = torch.empty_like(input, dtype=torch.bool)
+        kernel(input, other, out)
+        return out
+
+    if out is input or out is other:
+        # Writing into an operand would clobber values the kernel still
+        # needs, so compute into a temporary and copy back.
+        tmp = torch.empty_like(out)
+        kernel(input, other, tmp)
+        out.copy_(tmp)
+    else:
+        kernel(input, other, out)
+
+    return out
diff --git a/src/ntops/torch/logsigmoid.py b/src/ntops/torch/logsigmoid.py
new file mode 100644
index 0000000..8454761
--- /dev/null
+++ b/src/ntops/torch/logsigmoid.py
@@ -0,0 +1,13 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def logsigmoid(input, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    kernel = _cached_make(ntops.kernels.logsigmoid.premake, input.ndim)
+    kernel(input, out)
+    return out
diff --git a/src/ntops/torch/vdot.py b/src/ntops/torch/vdot.py
new file mode 100644
index 0000000..5fbb140
--- /dev/null
+++ b/src/ntops/torch/vdot.py
@@ -0,0 +1,22 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def vdot(input, other, *, out=None):
+    assert input.ndim == 1 and other.ndim == 1
+    assert input.shape[0] == other.shape[0]
+
+    if out is None:
+        out = torch.empty((), dtype=input.dtype, device=input.device)
+
+    # Use a temporary 1-D tensor as the accumulator.
+    accumulator = torch.empty((1,), dtype=input.dtype, device=input.device)
+
+    kernel = _cached_make(ntops.kernels.vdot.premake, input.ndim)
+    kernel(input, other, accumulator)
+
+    # Extract the scalar value from the accumulator.
+    out.copy_(accumulator[0])
+    return out
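The aliasing branch in logical_xor exists because the kernel streams reads and writes over the same memory; a minimal sketch of the case it guards against, assuming a CUDA device:

import torch
import ntops

a = torch.tensor([1, 0, 1], device="cuda")
b = torch.tensor([0, 1, 1], device="cuda")

# out aliases an operand, so the wrapper computes into a temporary and
# copies back rather than letting the kernel overwrite values it still reads.
ntops.torch.logical_xor(a, b, out=a)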
diff --git a/src/ntops/torch/where.py b/src/ntops/torch/where.py
new file mode 100644
index 0000000..762c817
--- /dev/null
+++ b/src/ntops/torch/where.py
@@ -0,0 +1,14 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def where(condition, input, other, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    # Assumes input, other, and out share the same dtype.
+    kernel = _cached_make(ntops.kernels.where.premake, input.ndim)
+    kernel(condition, input, other, out)
+    return out
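A minimal end-to-end check of the new wrappers against their torch counterparts, assuming a CUDA device and, as the where kernel's premake declares, a single shared dtype for condition, input, and other:

import torch
import ntops

x = torch.randn(4, 4, device="cuda")
cond = (x > 0).to(x.dtype)  # the kernel treats nonzero as True

torch.testing.assert_close(
    ntops.torch.where(cond, x, torch.zeros_like(x)),
    torch.where(cond != 0, x, torch.zeros_like(x)),
)
torch.testing.assert_close(
    ntops.torch.logsigmoid(x),
    torch.nn.functional.logsigmoid(x),
)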