diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index 084e52c..e79f526 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -33,6 +33,8 @@
     sigmoid,
     silu,
     sin,
+    logical_and,
+    logical_not,
     softmax,
     sub,
     tanh,
@@ -59,6 +61,8 @@
     "isnan",
     "layer_norm",
     "le",
+    "logical_and",
+    "logical_not",
     "lt",
     "mm",
     "mul",
diff --git a/src/ntops/kernels/logical_and.py b/src/ntops/kernels/logical_and.py
new file mode 100644
index 0000000..936230e
--- /dev/null
+++ b/src/ntops/kernels/logical_and.py
@@ -0,0 +1,18 @@
+import functools
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+from ntops.kernels.element_wise import arrangement
+
+def application(input1, input2, output):
+    # Remember the input dtype so the boolean result can be cast back.
+    dtype = input1.dtype
+    val1 = ntl.cast(input1, ntl.int1)
+    val2 = ntl.cast(input2, ntl.int1)
+    result = val1 & val2
+    output = ntl.cast(result, dtype)
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/logical_not.py b/src/ntops/kernels/logical_not.py
new file mode 100644
index 0000000..0df5090
--- /dev/null
+++ b/src/ntops/kernels/logical_not.py
@@ -0,0 +1,18 @@
+import functools
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+from ntops.kernels.element_wise import arrangement
+
+def application(input, output):
+    val_in = input
+    val_bool = ntl.cast(val_in, ntl.int1)
+    result_bool = ~val_bool
+    val_out = ntl.cast(result_bool, output.dtype)
+
+    # Assign the result to the output tile.
+    output = val_out
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+    return arrangement_, application, tensors
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 702877e..5bf43a8 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -36,7 +36,8 @@
 from ntops.torch.softmax import softmax
 from ntops.torch.sub import sub
 from ntops.torch.tanh import tanh
-
+from ntops.torch.logical_and import logical_and
+from ntops.torch.logical_not import logical_not
 __all__ = [
     "abs",
     "add",
@@ -71,6 +72,8 @@
     "rsqrt",
     "scaled_dot_product_attention",
     "sigmoid",
+    "logical_and",
+    "logical_not",
     "silu",
     "sin",
     "softmax",
diff --git a/src/ntops/torch/logical_and.py b/src/ntops/torch/logical_and.py
new file mode 100644
index 0000000..1c4a01f
--- /dev/null
+++ b/src/ntops/torch/logical_and.py
@@ -0,0 +1,11 @@
+import torch
+import ntops
+from ntops.torch.utils import _cached_make
+
+def logical_and(input, other, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+    kernel = _cached_make(ntops.kernels.logical_and.premake, input.ndim)
+    kernel(input, other, out)
+
+    return out
diff --git a/src/ntops/torch/logical_not.py b/src/ntops/torch/logical_not.py
new file mode 100644
index 0000000..b9fcbd1
--- /dev/null
+++ b/src/ntops/torch/logical_not.py
@@ -0,0 +1,29 @@
+import torch
+import ntops
+from ntops.torch.utils import _cached_make
+
+def logical_not(input, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    # Detect an in-place call on a non-contiguous tensor.
+    is_inplace = (out.data_ptr() == input.data_ptr())
+    is_strided = (not input.is_contiguous())
+
+    if is_inplace and is_strided:
+        # Create a contiguous temporary tensor.
+        # torch.empty keeps it contiguous instead of inheriting input's non-contiguous strides.
+        temp_out = torch.empty(input.shape, dtype=input.dtype, device=input.device)
+
+        kernel = _cached_make(ntops.kernels.logical_not.premake, input.ndim)
+
+        # 1. Read the non-contiguous input and write the contiguous temp_out (safe).
+        kernel(input, temp_out)
+
+        # 2. Copy back into the original output (PyTorch handles the stride conversion).
+        out.copy_(temp_out)
+    else:
+        kernel = _cached_make(ntops.kernels.logical_not.premake, input.ndim)
+        kernel(input, out)
+
+    return out
diff --git a/tests/test_logical_and.py b/tests/test_logical_and.py
new file mode 100644
index 0000000..bdabb86
--- /dev/null
+++ b/tests/test_logical_and.py
@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_logical_and(shape, dtype, device, rtol, atol):
+    input = torch.randint(0, 2, shape, device=device).to(dtype)
+    other = torch.randint(0, 2, shape, device=device).to(dtype)
+
+    ninetoothed_output = ntops.torch.logical_and(input, other)
+    reference_output = torch.logical_and(input, other).to(dtype)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_logical_not.py b/tests/test_logical_not.py
new file mode 100644
index 0000000..54afbc0
--- /dev/null
+++ b/tests/test_logical_not.py
@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_logical_not(shape, dtype, device, rtol, atol):
+    if dtype is torch.float16:
+        pytest.skip("float16 is not covered for logical_not")
+    input = torch.randint(0, 2, shape, device=device).to(dtype)
+    ninetoothed_output = ntops.torch.logical_not(input)
+    reference_output = torch.logical_not(input).to(dtype)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
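
For reference, a minimal usage sketch of the non-contiguous in-place path handled in src/ntops/torch/logical_not.py. Assumptions not taken from the patch: a CUDA device is available, ntops is importable, and the (4, 8) shape and column-slicing pattern are illustrative only.

    import torch

    import ntops

    # A non-contiguous view: every other column of a contiguous buffer.
    x = torch.randint(0, 2, (4, 8), device="cuda").to(torch.float32)[:, ::2]
    assert not x.is_contiguous()

    # Reference result computed before the in-place call overwrites x.
    reference = torch.logical_not(x).to(x.dtype)

    # In-place call: out aliases input and input is non-contiguous, so the
    # wrapper writes into a contiguous temporary tensor and copies it back.
    ntops.torch.logical_not(x, out=x)

    assert torch.equal(x, reference)

The out-of-place call (no out argument) and the contiguous in-place case both take the plain else branch, so only the aliased, strided combination pays for the temporary buffer and the extra copy_().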