9 changes: 9 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -36,6 +36,10 @@
softmax,
sub,
tanh,
logical_or,
logical_xor,
logsigmoid,
where,
)

__all__ = [
@@ -76,4 +80,9 @@
"softmax",
"sub",
"tanh",
"logical_or",
"logical_xor",
"logsigmoid",
"where",

]
21 changes: 21 additions & 0 deletions src/ntops/kernels/logical_or.py
@@ -0,0 +1,21 @@
import functools

from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
    # Nonzero elements count as True, matching torch.logical_or semantics.
    output = (input != 0) | (other != 0)  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
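
For reviewers unfamiliar with the premake pattern: the returned triple is what gets compiled into a launchable kernel. A minimal sketch, assuming ninetoothed.make accepts the (arrangement, application, tensors) triple the way the existing _cached_make helper suggests:

import ninetoothed

import ntops.kernels.logical_or

arrangement, application, tensors = ntops.kernels.logical_or.premake(ndim=2)
kernel = ninetoothed.make(arrangement, application, tensors)
# kernel can then be launched on same-shaped 2-D tensors: kernel(a, b, out)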
20 changes: 20 additions & 0 deletions src/ntops/kernels/logical_xor.py
@@ -0,0 +1,20 @@
import functools

from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
    # Nonzero elements count as True, matching torch.logical_xor semantics.
    output = (input != 0) ^ (other != 0)  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
27 changes: 27 additions & 0 deletions src/ntops/kernels/logsigmoid.py
@@ -0,0 +1,27 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
    # Compute in float32 for accuracy: logsigmoid(x) = log(sigmoid(x)).
    x = ntl.cast(input, ntl.float32)
    y = ntl.log(ntl.sigmoid(x))

    # Cast back to the output dtype (keeping the result consistent with the
    # input dtype and with torch behavior).
    output = ntl.cast(y, output.dtype)  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
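
One numerical note: log(sigmoid(x)) underflows to -inf for large negative x even in float32, since sigmoid(x) rounds to zero first. For reference only (not part of the kernel), the standard stable formulation in plain PyTorch is:

import torch

def logsigmoid_ref(x):
    # logsigmoid(x) = min(x, 0) - log1p(exp(-|x|)), stable for large |x|
    return x.clamp(max=0) - torch.log1p(torch.exp(-x.abs()))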
25 changes: 25 additions & 0 deletions src/ntops/kernels/where.py
@@ -0,0 +1,25 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(condition, input, other, output):
    # Treat a nonzero condition as True, matching torch.where semantics.
cond_bool = condition != 0
output = ntl.where(cond_bool, input, other) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype), # condition
Tensor(ndim, dtype=dtype), # input
Tensor(ndim, dtype=dtype), # other
Tensor(ndim, dtype=dtype), # output
)

return arrangement_, application, tensors
10 changes: 10 additions & 0 deletions src/ntops/torch/__init__.py
@@ -36,6 +36,11 @@
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh
from ntops.torch.logical_or import logical_or
from ntops.torch.logical_xor import logical_xor
from ntops.torch.logsigmoid import logsigmoid
from ntops.torch.where import where


__all__ = [
"abs",
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"logical_or",
"logical_xor",
"logsigmoid",
"where",

]
16 changes: 16 additions & 0 deletions src/ntops/torch/logical_or.py
@@ -0,0 +1,16 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def logical_or(input, other, *, out=None):
if out is None:
out = torch.empty_like(input, dtype=torch.bool)

kernel = _cached_make(ntops.kernels.logical_or.premake, input.ndim)

kernel(input, other, out)

return out
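
A quick usage sketch (the device, shapes, and values here are illustrative assumptions):

import torch

import ntops.torch

a = torch.tensor([0.0, 1.0, 2.0, 0.0], device="cuda")
b = torch.tensor([0.0, 0.0, 3.0, 4.0], device="cuda")

result = ntops.torch.logical_or(a, b)
assert torch.equal(result, torch.logical_or(a, b))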
35 changes: 35 additions & 0 deletions src/ntops/torch/logical_xor.py
@@ -0,0 +1,35 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def logical_xor(input, other, *, out=None):
kernel = _cached_make(ntops.kernels.logical_xor.premake, input.ndim)

if out is None:
out = torch.empty_like(input, dtype=torch.bool)
kernel(input, other, out)
return out

    # If out aliases an input, run the kernel into a temporary buffer so the
    # launch never reads elements it has already overwritten.
    if out is input or out is other:
        tmp = torch.empty_like(out)
        kernel(input, other, tmp)
        out.copy_(tmp)
else:
kernel(input, other, out)

return out
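
The aliasing branch above matters when out reuses an input; a small sketch of that case (device and values are illustrative):

import torch

import ntops.torch

a = torch.tensor([True, True, False], device="cuda")
b = torch.tensor([True, False, False], device="cuda")

ntops.torch.logical_xor(a, b, out=a)  # out aliases an input; routed through tmp
assert a.tolist() == [False, True, False]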
13 changes: 13 additions & 0 deletions src/ntops/torch/logsigmoid.py
@@ -0,0 +1,13 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def logsigmoid(input, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.logsigmoid.premake, input.ndim)
kernel(input, out)
return out
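
A minimal check against the eager reference, assuming a CUDA device is available:

import torch
import torch.nn.functional as F

import ntops.torch

x = torch.randn(1024, device="cuda")
torch.testing.assert_close(ntops.torch.logsigmoid(x), F.logsigmoid(x))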
22 changes: 22 additions & 0 deletions src/ntops/torch/vdot.py
@@ -0,0 +1,22 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def vdot(input, other, *, out=None):
assert input.ndim == 1 and other.ndim == 1
assert input.shape[0] == other.shape[0]

if out is None:
out = torch.empty((), dtype=input.dtype, device=input.device)

    # Create a temporary 1-element tensor to receive the reduction result.
accumulator = torch.empty((1,), dtype=input.dtype, device=input.device)

kernel = _cached_make(ntops.kernels.vdot.premake, input.ndim)
kernel(input, other, accumulator)

    # Copy the scalar value out of the accumulator.
out.copy_(accumulator[0])
return out
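
Usage sketch, assuming the matching ntops.kernels.vdot module exists in this tree (it is not part of this diff):

import torch

import ntops.torch

u = torch.randn(512, device="cuda")
v = torch.randn(512, device="cuda")

torch.testing.assert_close(ntops.torch.vdot(u, v), torch.vdot(u, v))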
14 changes: 14 additions & 0 deletions src/ntops/torch/where.py
@@ -0,0 +1,14 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def where(condition, input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

    # Assumes input, other, and out share the same dtype.
kernel = _cached_make(ntops.kernels.where.premake, input.ndim)
kernel(condition, input, other, out)
return out
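
A short usage sketch mirroring torch.where's three-argument form; the condition is passed with the same dtype as the values, per the assumption noted above (device and values are illustrative):

import torch

import ntops.torch

cond = torch.tensor([1.0, 0.0, 1.0], device="cuda")
x = torch.tensor([1.0, 2.0, 3.0], device="cuda")
y = torch.tensor([-1.0, -2.0, -3.0], device="cuda")

result = ntops.torch.where(cond, x, y)
assert torch.equal(result, torch.where(cond.bool(), x, y))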