4 changes: 4 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -33,6 +33,8 @@
sigmoid,
silu,
sin,
logical_and,
logical_not,
softmax,
sub,
tanh,
@@ -59,6 +61,8 @@
"isnan",
"layer_norm",
"le",
"logical_not",
"logical_and",
"lt",
"mm",
"mul",
18 changes: 18 additions & 0 deletions src/ntops/kernels/logical_and.py
@@ -0,0 +1,18 @@
import functools
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ntops.kernels.element_wise import arrangement

def application(input1, input2, output):
    # Get the input data type so the boolean result can be cast back.
    dtype = input1.dtype
    val1 = ntl.cast(input1, ntl.int1)
    val2 = ntl.cast(input2, ntl.int1)
    result = val1 & val2
    output = ntl.cast(result, dtype)

def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)
    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

    return arrangement_, application, tensors
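For context, premake only bundles the arrangement, the application function, and the symbolic tensors; compiling a callable kernel happens in the torch wrappers via _cached_make. A minimal sketch of the assumed flow, where the builder call ninetoothed.make is an assumption and not part of this diff:

import ninetoothed

import ntops

# Hypothetical manual construction; the wrappers below go through
# ntops.torch.utils._cached_make instead, which caches the compiled kernel.
arrangement_, application_, tensors = ntops.kernels.logical_and.premake(ndim=2)
kernel = ninetoothed.make(arrangement_, application_, tensors)  # assumed builder API
# The resulting kernel is then called as kernel(input, other, output) on CUDA tensors.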
18 changes: 18 additions & 0 deletions src/ntops/kernels/logical_not.py
@@ -0,0 +1,18 @@
import functools
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ntops.kernels.element_wise import arrangement

def application(input, output):
    val_in = input
    val_bool = ntl.cast(val_in, ntl.int1)
    result_bool = ~val_bool
    val_out = ntl.cast(result_bool, output.dtype)

    # Assign the result to the output.
    output = val_out

def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)
    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
    return arrangement_, application, tensors
5 changes: 4 additions & 1 deletion src/ntops/torch/__init__.py
@@ -36,7 +36,8 @@
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh

from ntops.torch.logical_and import logical_and
from ntops.torch.logical_not import logical_not
__all__ = [
"abs",
"add",
@@ -71,6 +72,8 @@
"rsqrt",
"scaled_dot_product_attention",
"sigmoid",
"logical_not",
"logical_and"
"silu",
"sin",
"softmax",
11 changes: 11 additions & 0 deletions src/ntops/torch/logical_and.py
@@ -0,0 +1,11 @@
import torch
import ntops
from ntops.torch.utils import _cached_make

def logical_and(input, other, *, out=None):
    if out is None:
        out = torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.logical_and.premake, input.ndim)
    kernel(input, other, out)

    return out
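A quick usage note on this wrapper: because the default output is allocated with torch.empty_like(input), the result keeps input's dtype, whereas torch.logical_and returns a bool tensor by default. A hedged usage sketch, assuming a CUDA device as in the tests:

import torch

import ntops

a = torch.tensor([0.0, 1.0, 2.0, 0.0], device="cuda")
b = torch.tensor([0.0, 3.0, 0.0, 4.0], device="cuda")

result = ntops.torch.logical_and(a, b)           # float32, like the inputs
reference = torch.logical_and(a, b).to(a.dtype)  # cast to match, as the tests do
assert torch.equal(result, reference)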
29 changes: 29 additions & 0 deletions src/ntops/torch/logical_not.py
@@ -0,0 +1,29 @@
import torch
import ntops
from ntops.torch.utils import _cached_make

def logical_not(input, *, out=None):
    if out is None:
        out = torch.empty_like(input)

    # Detect an in-place call on non-contiguous memory.
    is_inplace = out.data_ptr() == input.data_ptr()
    is_strided = not input.is_contiguous()

    if is_inplace and is_strided:
        # Create a contiguous temporary tensor; torch.empty guarantees it does
        # not inherit input's non-contiguous strides.
        temp_out = torch.empty(input.shape, dtype=input.dtype, device=input.device)

        kernel = _cached_make(ntops.kernels.logical_not.premake, input.ndim)

        # 1. Read the non-contiguous input and write the contiguous temp_out (safe).
        kernel(input, temp_out)

        # 2. Copy back into the original storage (PyTorch handles the stride conversion).
        out.copy_(temp_out)
    else:
        kernel = _cached_make(ntops.kernels.logical_not.premake, input.ndim)
        kernel(input, out)

    return out
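A hedged sketch of the case the comments above describe: an in-place call on a non-contiguous view (here a transposed tensor) takes the temporary-buffer path and is copied back into the original storage. Assumes a CUDA device:

import torch

import ntops

x = torch.randint(0, 2, (4, 8), device="cuda", dtype=torch.float32)
view = x.t()  # non-contiguous view sharing x's storage

# Compute the reference before the in-place call, since view gets overwritten.
reference = torch.logical_not(view).to(view.dtype)

out = ntops.torch.logical_not(view, out=view)
assert torch.equal(out, reference)
assert out.data_ptr() == x.data_ptr()  # the result landed in the original storage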
17 changes: 17 additions & 0 deletions tests/test_logical_and.py
@@ -0,0 +1,17 @@
import pytest
import torch

import ntops
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_logical_and(shape, dtype, device, rtol, atol):
    input = torch.randint(0, 2, shape, device=device).to(dtype)
    other = torch.randint(0, 2, shape, device=device).to(dtype)

    ninetoothed_output = ntops.torch.logical_and(input, other)
    reference_output = torch.logical_and(input, other).to(dtype)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
17 changes: 17 additions & 0 deletions tests/test_logical_not.py
@@ -0,0 +1,17 @@
import pytest
import torch

import ntops
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_logical_not(shape, dtype, device, rtol, atol):
    if dtype is torch.float16:
        pytest.skip("float16 is not covered by this test")

    input = torch.randint(0, 2, shape, device=device).to(dtype)

    ninetoothed_output = ntops.torch.logical_not(input)
    reference_output = torch.logical_not(input).to(dtype)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)