From fa28a7a4b3760f524ad03fea5fa556fc05153e10 Mon Sep 17 00:00:00 2001
From: greenhandhand <781740145@qq.com>
Date: Sat, 13 Dec 2025 16:02:27 +0800
Subject: [PATCH] Infinicore competition: implement logsumexp, lp_pool1d,
 lp_pool2d, lp_pool3d, max
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/ntops/kernels/__init__.py  |  10 ++
 src/ntops/kernels/logsumexp.py |  45 +++++
 src/ntops/kernels/lp_pool1d.py |  79 +++++++++
 src/ntops/kernels/lp_pool2d.py | 163 ++++++++++++++++++
 src/ntops/kernels/lp_pool3d.py | 160 ++++++++++++++++++
 src/ntops/kernels/max.py       |  61 +++++++
 src/ntops/torch/__init__.py    |  10 ++
 src/ntops/torch/logsumexp.py   |  27 +++
 src/ntops/torch/lp_pool1d.py   |  57 +++++++
 src/ntops/torch/lp_pool2d.py   |  65 ++++++++
 src/ntops/torch/lp_pool3d.py   |  71 ++++++++
 src/ntops/torch/max.py         |  49 ++++++
 tests/test_logsumexp.py        | 296 +++++++++++++++++++++++++++++++++
 tests/test_lp_pool1d.py        |  46 +++++
 tests/test_lp_pool2d.py        |  59 +++++++
 tests/test_lp_pool3d.py        |  63 +++++++
 tests/test_max.py              |  37 +++++
 17 files changed, 1298 insertions(+)
 create mode 100644 src/ntops/kernels/logsumexp.py
 create mode 100644 src/ntops/kernels/lp_pool1d.py
 create mode 100644 src/ntops/kernels/lp_pool2d.py
 create mode 100644 src/ntops/kernels/lp_pool3d.py
 create mode 100644 src/ntops/kernels/max.py
 create mode 100644 src/ntops/torch/logsumexp.py
 create mode 100644 src/ntops/torch/lp_pool1d.py
 create mode 100644 src/ntops/torch/lp_pool2d.py
 create mode 100644 src/ntops/torch/lp_pool3d.py
 create mode 100644 src/ntops/torch/max.py
 create mode 100644 tests/test_logsumexp.py
 create mode 100644 tests/test_lp_pool1d.py
 create mode 100644 tests/test_lp_pool2d.py
 create mode 100644 tests/test_lp_pool3d.py
 create mode 100644 tests/test_max.py

diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index 084e52c..15c540f 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
     softmax,
     sub,
     tanh,
+    logsumexp,
+    lp_pool1d,
+    lp_pool2d,
+    lp_pool3d,
+    max,
 )
 
 __all__ = [
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "logsumexp",
+    "lp_pool1d",
+    "lp_pool2d",
+    "lp_pool3d",
+    "max",
 ]
diff --git a/src/ntops/kernels/logsumexp.py b/src/ntops/kernels/logsumexp.py
new file mode 100644
index 0000000..44f6903
--- /dev/null
+++ b/src/ntops/kernels/logsumexp.py
@@ -0,0 +1,45 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.reduction import arrangement
+
+
+def _exp(x, dtype):
+    exp_dtype = dtype if dtype != ntl.float16 else ntl.float32
+    return ntl.cast(ntl.exp(ntl.cast(x, exp_dtype)), dtype)
+
+def _log(x, dtype):
+    log_dtype = dtype if dtype != ntl.float16 else ntl.float32
+    return ntl.cast(ntl.log(ntl.cast(x, log_dtype)), dtype)
+
+def application(input, output):
+    # input & output: (C // block_size,)
+    # input.dtype: (block_size,)
+    dtype = output.dtype.dtype
+    prev_max = ntl.cast(float("-inf"), dtype)
+    denominator = ntl.cast(0, dtype)
+
+    for i in range(input.shape[0]):
+        input_i = ntl.cast(input[i], dtype)
+        curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype)
+        input_max_diff_exp = _exp(input_i - curr_max, dtype)
+        prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype)
+        denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp)
+        prev_max = curr_max
+
+    output[0] = prev_max + _log(denominator, dtype)
+
+
+def premake(ndim, dim, dtype=None,
block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + tensors = ( + Tensor( + ndim, dtype=dtype, other=float("-inf"), shape_options={"constexpr": True} + ), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/lp_pool1d.py b/src/ntops/kernels/lp_pool1d.py new file mode 100644 index 0000000..bec4beb --- /dev/null +++ b/src/ntops/kernels/lp_pool1d.py @@ -0,0 +1,79 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed.language import libdevice +from ninetoothed import Tensor +from ninetoothed import Symbol + + +def arrangement(input, output, norm_type, kernel_size_val, kernel_size, stride, block_size, ceil_mode): + if block_size is None: + block_size = ninetoothed.block_size() + + # input: (N, C, L_in) output: (N, C, L_out) + + input_arranged = input.tile((1, 1, kernel_size), (1, 1, stride), floor_mode=not ceil_mode) + # => (N, C, L_out), dtype=(1, 1, k) + input_arranged = input_arranged.ravel() + # => (N, C, L_out, 1, 1, k) + input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1) + # => (N*C*L_out, k) + # k 的找到最近的 2 的倍数 + nearest_pow2 = 1 << (kernel_size - 1).bit_length() + input_arranged = input_arranged.tile((1, nearest_pow2)) + # => (..., k // nearest_pow2 = 1), dtype=(1, nearest_pow2) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + # => (..., 1), dtype=(nearest_pow2, ) + input_arranged = input_arranged.tile((block_size, -1)) + # => (..., 1), dtype=(block_size, 1), dtype=(nearest_pow2, ) + input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1) + # => (..., 1), dtype=(block_size, nearest_pow2) + + output_arranged = output.tile((1, 1, 1)) + # => (N, C, L_out), dtype=(1, 1, 1) + output_arranged = output_arranged.ravel() + # => (N, C, L_out, 1, 1, 1) + output_arranged = output_arranged.flatten(end_dim=3).flatten(start_dim=1) + # => (N*C*L_out, 1) + output_arranged = output_arranged.tile((block_size, -1)) + # => (..., 1), dtype=(block_size, 1) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + # => (..., 1), dtype=(block_size, ) + + return input_arranged, output_arranged, norm_type, kernel_size_val + + +def _pow(x, norm, dtype): + pow_dtype = dtype if dtype != ntl.float16 else ntl.float32 + return ntl.cast(libdevice.pow(ntl.cast(x, pow_dtype), norm), dtype) + +def application(input, output, norm_type, kernel_size): + # input: (block_size, nearest_pow2) + # output: (block_size) + dtype = input.dtype + mask = input < 1e20 + cnt = ntl.sum(ntl.cast(mask, ntl.int32), axis=1) + input_masked = ntl.where(~mask, 0, input) + x_pow = _pow(input_masked, norm_type, dtype) + acc_sim = ntl.sum(x_pow, 1) / cnt * kernel_size + output = _pow(acc_sim, 1.0 / norm_type, dtype) + + +def premake(ndim, kernel_size, stride, ceil_mode=False, dtype=None, block_size=None): + arrangement_ = functools.partial( + arrangement, + kernel_size=kernel_size, + stride=stride, + block_size=block_size, + ceil_mode=ceil_mode, + ) + + tensors = ( + Tensor(ndim, dtype=dtype, other=float("inf")), # input + Tensor(ndim, dtype=dtype), # output + Tensor(0, dtype=dtype), # norm_type + Tensor(0, dtype=dtype, constexpr=True), # kernel_size + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/lp_pool2d.py b/src/ntops/kernels/lp_pool2d.py new file mode 100644 index 0000000..7bdbbf6 --- /dev/null +++ b/src/ntops/kernels/lp_pool2d.py @@ -0,0 +1,163 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl 
+from ninetoothed.language import libdevice +from ninetoothed import Tensor +from ninetoothed import Symbol + + +def _pow(x, norm, dtype): + pow_dtype = dtype if dtype != ntl.float16 else ntl.float32 + return ntl.cast(libdevice.pow(ntl.cast(x, pow_dtype), norm), dtype) + +def arrangement_ceil_mode(*tensors, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size, ceil_mode): + input, output, norm_type, kernel_size_flatted = tensors + if block_size is None: + block_size = ninetoothed.block_size() + + # input: (N, C, H_in, W_in) output: (N, C, H_out, W_out) + # ref. example 里的 max_pool2d arrangement + + input_arranged = input.tile((1, 1, kernel_size_h, kernel_size_w), (1, 1, stride_h, stride_w), floor_mode=not ceil_mode) + # => (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w) + input_arranged = input_arranged.ravel() + # => (N, C, H_out, W_out, 1, 1, k_h, k_w) + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + # => (N*C*H_out*W_out, k_h*k_w) + # k_h*k_w 的找到最近的 2 的倍数 + nearest_pow2 = 1 << (kernel_size_h * kernel_size_w - 1).bit_length() + input_arranged = input_arranged.tile((1, nearest_pow2)) + # => (..., k_h*k_w // nearest_pow2 = 1), dtype=(1, nearest_pow2) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + # => (..., 1), dtype=(nearest_pow2, ) + input_arranged = input_arranged.tile((block_size, -1)) + # => (..., 1), dtype=(block_size, 1), dtype=(nearest_pow2, ) + input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1) + # => (..., 1), dtype=(block_size, nearest_pow2) + + output_arranged = output.tile((1, 1, 1, 1)) + # => (N, C, H_out, W_out), dtype=(1, 1, 1, 1) + output_arranged = output_arranged.ravel() + # => (N, C, H_out, W_out, 1, 1, 1, 1) + output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) + # => (N*C*H_out*W_out, 1) + output_arranged = output_arranged.tile((block_size, -1)) + # => (..., 1), dtype=(block_size, 1) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + # => (..., 1), dtype=(block_size, ) + + return input_arranged, output_arranged, norm_type, kernel_size_flatted + + + +def application_ceil_mode(input, output, norm_type, kernel_size_flatted): + # input: (block_size, nearest_pow2) arrangement 之后最外层被用于并行计算 + # output: (block_size, ) + # 这里 torch 实现与文档上的不一致,文档上描述的是 sum(windows^p)^(1/p) + # 实际上 torch 的实现是 mean(windows^p) * (kernel_size_h * kernel_size_w))^(1/p) + # 这在 strides=kernel_size 时的结果是一致的,但是在 strides!=kernel_size && ceil_mode=True 时会有差异 + # 主要体现在边界处理上, torch 的算法会放大边界处的值,因为边界处的窗口内有效元素个数少于 kernel_size_h * kernel_size_w + # 下面给出了两种不同的实现 + # 这是补 0 的实现 (要使用这种实现,请将input的默认值修改为 0) + # dtype = input.dtype + # x_pow = _pow(input, norm_type, dtype) + # acc = ntl.sum(x_pow, axis=0) + # output = _pow(acc, 1.0 / norm_type, dtype) + + # 为了通过测试,下面使用的是与 torch 实现一致的版本 + dtype = input.dtype + mask = input < 1e20 + cnt = ntl.sum(ntl.cast(mask, ntl.int32), axis=1) + input_masked = ntl.where(~mask, 0, input) + x_pow = _pow(input_masked, norm_type, dtype) + acc_sim = ntl.sum(x_pow, 1) / cnt * kernel_size_flatted + output = _pow(acc_sim, 1.0 / norm_type, dtype) + + +def premake_ceil_mode(ndim, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size=None, ceil_mode=False, dtype=None): + arrangement_ = functools.partial( + arrangement_ceil_mode, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_h=stride_h, + stride_w=stride_w, + block_size=block_size, + ceil_mode=ceil_mode, + ) + + tensors = ( + Tensor(ndim, dtype=dtype, other=float("inf")), # input + Tensor(ndim, dtype=dtype), # output + 
Tensor(0, dtype=dtype), # norm_type + Tensor(0, dtype=dtype), # kernel_size_flatted + ) + + return arrangement_, application_ceil_mode, tensors + + + +def arrangement(input, output, norm_type, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size, ceil_mode): + if block_size is None: + block_size = ninetoothed.block_size() + + # input: (N, C, H_in, W_in) output: (N, C, H_out, W_out) + # ref. example 里的 max_pool2d arrangement + + input_arranged = input.tile((1, 1, kernel_size_h, kernel_size_w), (1, 1, stride_h, stride_w), floor_mode=not ceil_mode) + # => (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w) + input_arranged = input_arranged.ravel() + # => (N, C, H_out, W_out, 1, 1, k_h, k_w) + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + # => (N*C*H_out*W_out, k_h*k_w) + # k_h*k_w 的找到最近的 2 的倍数 + nearest_pow2 = 1 << (kernel_size_h * kernel_size_w - 1).bit_length() + input_arranged = input_arranged.tile((1, nearest_pow2)) + # => (..., k_h*k_w // nearest_pow2 = 1), dtype=(1, nearest_pow2) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + # => (..., 1), dtype=(nearest_pow2, ) + input_arranged = input_arranged.tile((block_size, -1)) + # => (..., 1), dtype=(block_size, 1), dtype=(nearest_pow2, ) + input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1) + # => (..., 1), dtype=(block_size, nearest_pow2) + + output_arranged = output.tile((1, 1, 1, 1)) + # => (N, C, H_out, W_out), dtype=(1, 1, 1, 1) + output_arranged = output_arranged.ravel() + # => (N, C, H_out, W_out, 1, 1, 1, 1) + output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) + # => (N*C*H_out*W_out, 1) + output_arranged = output_arranged.tile((block_size, -1)) + # => (..., 1), dtype=(block_size, 1) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + # => (..., 1), dtype=(block_size, ) + + return input_arranged, output_arranged, norm_type + +def application(input, output, norm_type): + # input: (block_size, nearest_pow2) + # output: (block_size, ) + dtype = input.dtype + x_pow = _pow(input, norm_type, dtype) + acc = ntl.sum(x_pow, axis=1) + output = _pow(acc, 1.0 / norm_type, dtype) + + +def premake(ndim, kernel_size_h, kernel_size_w, stride_h, stride_w, block_size=None, ceil_mode=False, dtype=None): + arrangement_ = functools.partial( + arrangement, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_h=stride_h, + stride_w=stride_w, + block_size=block_size, + ceil_mode=ceil_mode, + ) + + tensors = ( + Tensor(ndim, dtype=dtype, other=0), # input + Tensor(ndim, dtype=dtype), # output + Tensor(0, dtype=dtype), # norm_type + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/lp_pool3d.py b/src/ntops/kernels/lp_pool3d.py new file mode 100644 index 0000000..d872c01 --- /dev/null +++ b/src/ntops/kernels/lp_pool3d.py @@ -0,0 +1,160 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed.language import libdevice +from ninetoothed import Tensor +from ninetoothed import Symbol + + +def _pow(x, norm, dtype): + pow_dtype = dtype if dtype != ntl.float16 else ntl.float32 + return ntl.cast(libdevice.pow(ntl.cast(x, pow_dtype), norm), dtype) + +def arrangement_ceil_mode( + *tensors, + kernel_size_d, + kernel_size_h, + kernel_size_w, + stride_d, + stride_h, + stride_w, + block_size, + ceil_mode, +): + """ceil_mode 下的 arrangement, 需要额外传入 kernel_size_flatted""" + input, output, norm_type, kernel_size_flatted = tensors + input_arranged, output_arranged, norm_type = arrangement( + input, + output, + 
norm_type, + kernel_size_d, + kernel_size_h, + kernel_size_w, + stride_d, + stride_h, + stride_w, + block_size, + ceil_mode, + ) + return input_arranged, output_arranged, norm_type, kernel_size_flatted + + + +def application_ceil_mode(input, output, norm_type, kernel_size_flatted): + # input: (block_size, nearest_pow2) + # output: (block_size, ) + # INFO: 下面的内容同时适用于 lp_pool2d 和 lp_pool3d + # 这里 torch 实现与文档上的不一致,文档上描述的是 sum(windows^p)^(1/p) + # 实际上 torch 的实现是 mean(windows^p) * (kernel_size_h * kernel_size_w))^(1/p) + # 这在 strides=kernel_size 时的结果是一致的,但是在 strides!=kernel_size && ceil_mode=True 时会有差异 + # 主要体现在边界处理上, torch 的算法会放大边界处的值,因为边界处的窗口内有效元素个数少于 kernel_size_h * kernel_size_w + # 下面给出了两种不同的实现 + # 这是补 0 的实现 (要使用这种实现,请将input的默认值修改为 0) + # dtype = input.dtype + # x_pow = _pow(input, norm_type, dtype) + # acc = ntl.sum(x_pow, axis=0) + # output = _pow(acc, 1.0 / norm_type, dtype) + + # 我把 ceil_mode 和普通的实现区分开来了 + # 为了通过测试,下面使用的是与 torch 实现一致的版本 + dtype = input.dtype + mask = input < 1e20 + cnt = ntl.sum(ntl.cast(mask, ntl.int32), axis=1) + input_masked = ntl.where(~mask, 0, input) + x_pow = _pow(input_masked, norm_type, dtype) + acc_sim = ntl.sum(x_pow, 1) / cnt * kernel_size_flatted + output = _pow(acc_sim, 1.0 / norm_type, dtype) + + +def premake_ceil_mode(ndim, kernel_size_d, kernel_size_h, kernel_size_w, stride_d, stride_h, stride_w, block_size=None, ceil_mode=False, dtype=None): + arrangement_ = functools.partial( + arrangement_ceil_mode, + kernel_size_d=kernel_size_d, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_d=stride_d, + stride_h=stride_h, + stride_w=stride_w, + block_size=block_size, + ceil_mode=ceil_mode, + ) + + tensors = ( + Tensor(ndim, dtype=dtype, other=float("inf")), # input + Tensor(ndim, dtype=dtype), # output + Tensor(0, dtype=dtype), # norm_type + Tensor(0, dtype=dtype), # kernel_size_flatted + ) + + return arrangement_, application_ceil_mode, tensors + + + +def arrangement(input, output, norm_type, kernel_size_d, kernel_size_h, kernel_size_w, stride_d, stride_h, stride_w, block_size, ceil_mode): + if block_size is None: + block_size = ninetoothed.block_size() + + # input: (N, C, D_in, H_in, W_in) output: (N, C, D_out, H_out, W_out) + # ref. 
example 里的 max_pool2d arrangement + + input_arranged = input.tile((1, 1, kernel_size_d, kernel_size_h, kernel_size_w), (1, 1, stride_d, stride_h, stride_w), floor_mode=not ceil_mode) + # => (N, C, D_out, H_out, W_out), dtype=(1, 1, k_d, k_h, k_w) + input_arranged = input_arranged.ravel() + # => (N, C, D_out, H_out, W_out, 1, 1, k_d, k_h, k_w) + input_arranged = input_arranged.flatten(end_dim=5).flatten(start_dim=1) + # => (N*C*D_out*H_out*W_out, k_d*k_h*k_w) + + # k_d*k_h*k_w 的找到最近的 2 的倍数 + nearest_pow2 = 1 << (kernel_size_d * kernel_size_h * kernel_size_w - 1).bit_length() + input_arranged = input_arranged.tile((1, nearest_pow2)) + # => (..., k_d*k_h*k_w // nearest_pow2 = 1), dtype=(1, nearest_pow2) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + # => (..., 1), dtype=(nearest_pow2, ) + input_arranged = input_arranged.tile((block_size, -1)) + # => (..., 1), dtype=(block_size, 1), dtype=(nearest_pow2, ) + input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1) + # => (..., 1), dtype=(block_size, nearest_pow2) + + output_arranged = output.tile((1, 1, 1, 1, 1)) + # => (N, C, D_out, H_out, W_out), dtype=(1, 1, 1, 1, 1) + output_arranged = output_arranged.ravel() + # => (N, C, D_out, H_out, W_out, 1, 1, 1, 1) + output_arranged = output_arranged.flatten(end_dim=5).flatten(start_dim=1) + # => (N*C*D_out*H_out*W_out, 1) + output_arranged = output_arranged.tile((block_size, -1)) + # => (..., 1), dtype=(block_size, 1) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + # => (..., 1), dtype=(block_size, ) + + return input_arranged, output_arranged, norm_type + +def application(input, output, norm_type): + # input: (block_size, nearest_pow2) + # output: (block_size, ) + dtype = input.dtype + x_pow = _pow(input, norm_type, dtype) + acc = ntl.sum(x_pow, axis=1) + output = _pow(acc, 1.0 / norm_type, dtype) + + +def premake(ndim, kernel_size_d, kernel_size_h, kernel_size_w, stride_d, stride_h, stride_w, block_size=None, ceil_mode=False, dtype=None): + arrangement_ = functools.partial( + arrangement, + kernel_size_d=kernel_size_d, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_d=stride_d, + stride_h=stride_h, + stride_w=stride_w, + block_size=block_size, + ceil_mode=ceil_mode, + ) + + tensors = ( + Tensor(ndim, dtype=dtype, other=0), # input + Tensor(ndim, dtype=dtype), # output + Tensor(0, dtype=dtype), # norm_type + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/max.py b/src/ntops/kernels/max.py new file mode 100644 index 0000000..84c43ae --- /dev/null +++ b/src/ntops/kernels/max.py @@ -0,0 +1,61 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +import math +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, output, output_idx): + # input: (C // block_size, ) dtype: (block_size, ) + # output: (C // block_size, ) dtype: (block_size, ) + dtype = output.dtype.dtype + prev_max = ntl.cast(float("-inf"), dtype) + global_idx = -1 + offset = input.dtype.shape[0] + + for i in range(input.shape[0]): + curr_idx = ntl.argmax(input[i], 0) + (i * offset) + curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input[i])), dtype) + global_idx = curr_idx if curr_max > prev_max else global_idx + prev_max = curr_max + + output[0] = prev_max + output_idx[0] = global_idx + + +def premake(ndim, dim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + tensors = ( + Tensor( + ndim, dtype=dtype, 
other=float("-inf"), shape_options={"constexpr": True} + ), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=ninetoothed.int32) + ) + + return arrangement_, application, tensors + +def arrangement_all_elements(input, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + input = input.flatten().tile((block_size,)) + output = output.tile((1,)) + return input, output + +def application_all_elements(input, output): + output[0] = ntl.max(input, 0) + +def premake_all_elements(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement_all_elements, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype, other=float("-inf"), shape_options={"constexpr": True}), + Tensor(1, dtype=dtype), + ) + + return arrangement_, application_all_elements, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 702877e..95ac8d2 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -36,6 +36,11 @@ from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh +from ntops.torch.logsumexp import logsumexp +from ntops.torch.lp_pool1d import lp_pool1d +from ntops.torch.lp_pool2d import lp_pool2d +from ntops.torch.lp_pool3d import lp_pool3d +from ntops.torch.max import max __all__ = [ "abs", @@ -76,4 +81,9 @@ "softmax", "sub", "tanh", + "logsumexp", + "lp_pool1d", + "lp_pool2d", + "lp_pool3d", + "max", ] diff --git a/src/ntops/torch/logsumexp.py b/src/ntops/torch/logsumexp.py new file mode 100644 index 0000000..5fde479 --- /dev/null +++ b/src/ntops/torch/logsumexp.py @@ -0,0 +1,27 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def logsumexp(input, dim, keepdim=False, *, out=None): + tensor_dtype = out.dtype if out is not None else input.dtype + + output_shape = list(input.shape) + output_shape[dim] = 1 + + temp_out = torch.empty(output_shape, dtype=tensor_dtype, device=input.device) + + block_size = 256 + kernel = _cached_make(ntops.kernels.logsumexp.premake, input.ndim, dim, block_size) + kernel(input, temp_out) + + if not keepdim: + del output_shape[dim] + temp_out = temp_out.view(output_shape) + + if out is not None: + out.copy_(temp_out) + return out + + return temp_out diff --git a/src/ntops/torch/lp_pool1d.py b/src/ntops/torch/lp_pool1d.py new file mode 100644 index 0000000..5635013 --- /dev/null +++ b/src/ntops/torch/lp_pool1d.py @@ -0,0 +1,57 @@ +import math +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): + """ + 一维 Lp 池化 + + 参数: + input: (N, C, L_in) 输入张量 + norm_type: Lp 范数的 p 值(1.0, 2.0, 等) + kernel_size: 窗口大小 + stride: 步长,默认等于 kernel_size + ceil_mode: 是否使用 ceil 模式计算输出长度 + + 返回: + output: (N, C, L_out) 输出张量 + """ + + assert input.ndim == 3 or input.ndim == 2, ( + "Input tensor must be 3-dimensional (N, C, L_in) or (C, L_in)" + ) + if input.ndim == 2: + input = input.view(1, input.shape[0], input.shape[1]) + + if stride is None: + stride = kernel_size + + L_in = input.shape[-1] + + # 计算输出长度 + if ceil_mode: + L_out = math.ceil((L_in - kernel_size + stride) / stride) + else: + L_out = math.floor((L_in - kernel_size + stride) / stride) + + output_shape = (input.shape[0], input.shape[1], L_out) + + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + + block_size = 1024 + kernel = _cached_make( + ntops.kernels.lp_pool1d.premake, + input.ndim, + kernel_size, + stride, + 
ceil_mode=ceil_mode, + dtype=input.dtype, + block_size=block_size + ) + + kernel(input, output, norm_type, kernel_size) + + return output diff --git a/src/ntops/torch/lp_pool2d.py b/src/ntops/torch/lp_pool2d.py new file mode 100644 index 0000000..2bd8c5d --- /dev/null +++ b/src/ntops/torch/lp_pool2d.py @@ -0,0 +1,65 @@ +import math +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def lp_pool2d(input, norm_type, kernel_size: int | tuple[int, int], stride: None | int | tuple[int, int] = None, ceil_mode=False): + assert input.ndim == 4 or input.ndim == 3, "Input tensor must be 4-dimensional (N, C, H_in, W_in) or 3-dimensional (C, H_in, W_in)" + + if input.ndim == 3: + input = input.unsqueeze(0) # 添加 batch 维度 + + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride) + + if stride is None: + stride = kernel_size + + # 计算输出长度 + H_in, W_in = input.shape[-2], input.shape[-1] + if ceil_mode: + H_out = math.ceil((H_in - kernel_size[0] + stride[0]) / stride[0]) + W_out = math.ceil((W_in - kernel_size[1] + stride[1]) / stride[1]) + else: + H_out = math.floor((H_in - kernel_size[0] + stride[0]) / stride[0]) + W_out = math.floor((W_in - kernel_size[1] + stride[1]) / stride[1]) + + output_shape = (input.shape[0], input.shape[1], H_out, W_out) + + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + + block_size = 1024 + if ceil_mode: + kernel = _cached_make( + ntops.kernels.lp_pool2d.premake_ceil_mode, + input.ndim, + kernel_size[0], + kernel_size[1], + stride[0], + stride[1], + block_size=block_size, + ceil_mode=ceil_mode, + dtype=input.dtype + ) + kernel(input, output, norm_type, kernel_size[0] * kernel_size[1]) + else: + kernel = _cached_make( + ntops.kernels.lp_pool2d.premake, + input.ndim, + kernel_size[0], + kernel_size[1], + stride[0], + stride[1], + block_size=block_size, + ceil_mode=ceil_mode, + dtype=input.dtype + ) + kernel(input, output, norm_type) + + + return output diff --git a/src/ntops/torch/lp_pool3d.py b/src/ntops/torch/lp_pool3d.py new file mode 100644 index 0000000..3438e24 --- /dev/null +++ b/src/ntops/torch/lp_pool3d.py @@ -0,0 +1,71 @@ +import math +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def lp_pool3d(input, norm_type, kernel_size: int | tuple[int, int, int], stride: None | int | tuple[int, int, int] = None, ceil_mode=False): + assert input.ndim == 5 or input.ndim == 4, "Input tensor must be 4-dimensional (N, C, D_in, H_in, W_in) or 3-dimensional (C, D_in, H_in, W_in)" + + if input.ndim == 4: + input = input.unsqueeze(0) # 添加 batch 维度 + + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride, stride) + + if stride is None: + stride = kernel_size + + # 计算输出长度 + N, C, D_in, H_in, W_in = input.shape + if ceil_mode: + D_out = math.ceil((D_in - kernel_size[0] + stride[0]) / stride[0]) + H_out = math.ceil((H_in - kernel_size[1] + stride[1]) / stride[1]) + W_out = math.ceil((W_in - kernel_size[2] + stride[2]) / stride[2]) + else: + D_out = math.floor((D_in - kernel_size[0] + stride[0]) / stride[0]) + H_out = math.floor((H_in - kernel_size[1] + stride[1]) / stride[1]) + W_out = math.floor((W_in - kernel_size[2] + stride[2]) / stride[2]) + + output_shape = (N, C, D_out, H_out, W_out) + + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + + block_size = 256 + if ceil_mode: + kernel = _cached_make( + 
ntops.kernels.lp_pool3d.premake_ceil_mode, + input.ndim, + kernel_size[0], + kernel_size[1], + kernel_size[2], + stride[0], + stride[1], + stride[2], + block_size=block_size, + ceil_mode=ceil_mode, + dtype=input.dtype, + ) + kernel(input, output, norm_type, kernel_size[0] * kernel_size[1] * kernel_size[2]) + else: + kernel = _cached_make( + ntops.kernels.lp_pool3d.premake, + input.ndim, + kernel_size[0], + kernel_size[1], + kernel_size[2], + stride[0], + stride[1], + stride[2], + block_size=block_size, + ceil_mode=ceil_mode, + dtype=input.dtype + ) + kernel(input, output, norm_type) + + + return output diff --git a/src/ntops/torch/max.py b/src/ntops/torch/max.py new file mode 100644 index 0000000..822d867 --- /dev/null +++ b/src/ntops/torch/max.py @@ -0,0 +1,49 @@ +import torch + +import ntops +import ninetoothed +from ntops.torch.utils import _cached_make +import builtins +import math + +def max(input, dim: int | None = None, keepdim=False, *, out=None): + if dim is None: + current = input + + # 递归地应用 max kernel 直到只剩一个元素 + block_size = 1024 + while current.numel() > 1: + output_shape = (math.ceil(current.numel() / block_size),) + output = torch.empty(output_shape, dtype=current.dtype, device=current.device) + kernel = _cached_make(ntops.kernels.max.premake_all_elements, current.ndim, current.dtype, block_size) + kernel(current, output) + current = output + + result = current.view(()) + + if out is not None: + out.copy_(result) + return out + + return result + else: + output_shape = list(input.shape) + output_shape[dim] = 1 + + temp_out = torch.empty(output_shape, dtype=input.dtype, device=input.device) + temp_out_idx = torch.empty(output_shape, dtype=torch.int64, device=input.device) + + block_size = 1024 + kernel = _cached_make(ntops.kernels.max.premake, input.ndim, dim, block_size) + kernel(input, temp_out, temp_out_idx) + + if not keepdim: + del output_shape[dim] + temp_out = temp_out.view(output_shape) + temp_out_idx = temp_out_idx.view(output_shape) + + if out is not None: + out.copy_(temp_out) + return out, temp_out_idx + + return temp_out, temp_out_idx diff --git a/tests/test_logsumexp.py b/tests/test_logsumexp.py new file mode 100644 index 0000000..9f688e8 --- /dev/null +++ b/tests/test_logsumexp.py @@ -0,0 +1,296 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +@pytest.mark.parametrize("keepdim", (False, True)) +def test_logsumexp(shape, dtype, device, rtol, atol, keepdim): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + dim = random.randint(0, input_tensor.ndim - 1) + + if random.choice((True, False)): + dim = dim - input_tensor.ndim + + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + ntops_output = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + assert torch.allclose(ntops_output, reference_output, rtol=rtol, atol=atol) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("keepdim", [False, True]) +def test_logsumexp_with_strided_output(dtype, keepdim): + """测试 logsumexp 使用非连续(strided)输出张量的情况""" + device = "cuda" + + # 创建输入张量 + input_tensor = torch.randn(4, 5, 6, dtype=dtype, device=device) + dim = 2 + + # 计算参考输出 + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 创建一个具有不同 strides 的输出张量 + if keepdim: + # 
创建一个更大的张量,然后切片得到具有非标准 strides 的子张量 + large_tensor = torch.empty(4, 5, 3, dtype=dtype, device=device) + out = large_tensor[:, :, :1] # shape (4, 5, 1) 但 strides 为 (15, 3, 1) 而不是标准的 (5, 1, 1) + else: + # 创建一个更大的张量,然后切片得到具有非标准 strides 的子张量 + large_tensor = torch.empty(4, 5, 2, dtype=dtype, device=device) + out = large_tensor[:, :, 0] # shape (4, 5) 但 strides 为 (10, 2) 而不是标准的 (5, 1) + + # 使用 ntops 的 logsumexp,传入 strided 输出 + result = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim, out=out) + + # 验证结果 + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-6 + assert torch.allclose(result, reference_output, rtol=rtol, atol=atol) + assert result is out # 确保返回的是传入的 out 张量 + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_logsumexp_with_transposed_output(dtype): + """测试 logsumexp 使用转置(非连续)输出张量的情况""" + device = "cuda" + + # 创建输入张量 + input_tensor = torch.randn(3, 4, 5, dtype=dtype, device=device) + dim = 1 + keepdim = True + + # 计算参考输出 + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 创建一个转置的输出张量(非连续) + out_base = torch.empty(1, 3, 5, dtype=dtype, device=device) + out = out_base.transpose(0, 1) # shape (3, 1, 5),但内存布局非连续 + + # 使用 ntops 的 logsumexp + result = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim, out=out) + + # 验证结果 + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-6 + assert torch.allclose(result, reference_output, rtol=rtol, atol=atol) + assert result is out + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_logsumexp_with_strided_slice_output(dtype): + """测试 logsumexp 使用步进切片(strided slice)输出张量的情况""" + device = "cuda" + + # 创建输入张量 + input_tensor = torch.randn(2, 8, 6, dtype=dtype, device=device) + dim = 2 + keepdim = True + + # 计算参考输出 + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 创建一个更大的张量,使用步进切片得到非连续的子张量 + large_tensor = torch.empty(2, 8, 4, dtype=dtype, device=device) + out = large_tensor[:, ::2, :1] # shape (2, 4, 1),strides 非标准 + + # 调整输入以匹配输出的第二维 + input_tensor_adjusted = input_tensor[:, ::2, :] # shape (2, 4, 6) + reference_output_adjusted = torch.logsumexp(input_tensor_adjusted, dim=dim, keepdim=keepdim) + + # 使用 ntops 的 logsumexp + result = ntops.torch.logsumexp(input_tensor_adjusted, dim=dim, keepdim=keepdim, out=out) + + # 验证结果 + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-6 + assert torch.allclose(result, reference_output_adjusted, rtol=rtol, atol=atol) + assert result is out + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("keepdim", [False, True]) +@pytest.mark.parametrize("dim", [0, 1, 2]) +def test_logsumexp_with_contiguous_out(dtype, keepdim, dim): + """测试 logsumexp 使用正常的连续(contiguous)输出张量""" + device = "cuda" + + # 创建输入张量 + input_tensor = torch.randn(4, 5, 6, dtype=dtype, device=device) + + # 计算参考输出 + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 创建一个连续的输出张量 + out = torch.empty_like(reference_output) + + # 使用 ntops 的 logsumexp,传入连续输出 + result = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim, out=out) + + # 验证结果 + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-6 + assert torch.allclose(result, reference_output, rtol=rtol, atol=atol) + assert result 
is out # 确保返回的是传入的 out 张量 + assert out.is_contiguous() # 确保输出是连续的 + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("keepdim", [False, True]) +def test_logsumexp_with_out_different_strides_dim0(dtype, keepdim): + """测试 logsumexp 在 dim=0 时使用 strided 输出张量""" + device = "cuda" + + # 创建输入张量 + input_tensor = torch.randn(6, 4, 5, dtype=dtype, device=device) + dim = 0 + + # 计算参考输出 + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 创建一个具有不同 strides 的输出张量 + if keepdim: + large_tensor = torch.empty(3, 4, 5, dtype=dtype, device=device) + out = large_tensor[:1, :, :] # shape (1, 4, 5) 但 strides 为 (20, 5, 1) 而不是标准的 (20, 5, 1) + else: + large_tensor = torch.empty(4, 5, 2, dtype=dtype, device=device) + out = large_tensor[:, :, 0] # shape (4, 5) 但 strides 为 (10, 2) 而不是标准的 (5, 1) + + # 使用 ntops 的 logsumexp + result = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim, out=out) + + # 验证结果 + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-6 + assert torch.allclose(result, reference_output, rtol=rtol, atol=atol) + assert result is out + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("keepdim", [False, True]) +def test_logsumexp_with_out_different_strides_dim1(dtype, keepdim): + """测试 logsumexp 在 dim=1 时使用 strided 输出张量""" + device = "cuda" + + # 创建输入张量 + input_tensor = torch.randn(3, 8, 5, dtype=dtype, device=device) + dim = 1 + + # 计算参考输出 + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 创建一个具有不同 strides 的输出张量 + if keepdim: + large_tensor = torch.empty(3, 2, 5, dtype=dtype, device=device) + out = large_tensor[:, :1, :] # shape (3, 1, 5) 但 strides 为 (10, 5, 1) 而不是标准的 (5, 5, 1) + else: + large_tensor = torch.empty(3, 5, 3, dtype=dtype, device=device) + out = large_tensor[:, :, 0] # shape (3, 5) 但 strides 为 (15, 3) 而不是标准的 (5, 1) + + # 使用 ntops 的 logsumexp + result = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim, out=out) + + # 验证结果 + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-6 + assert torch.allclose(result, reference_output, rtol=rtol, atol=atol) + assert result is out + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_logsumexp_with_out_permuted(dtype): + """测试 logsumexp 使用 permute 后的输出张量(非连续)""" + device = "cuda" + + # 创建输入张量 + input_tensor = torch.randn(3, 4, 5, 6, dtype=dtype, device=device) + dim = 2 + keepdim = True + + # 计算参考输出 + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 创建一个 permute 后的输出张量(非连续) + # 使用 permute(3, 2, 1, 0) 来确保输出是非连续的 + out_base = torch.empty(6, 1, 4, 3, dtype=dtype, device=device) + out = out_base.permute(3, 2, 1, 0) # shape (3, 4, 1, 6),内存布局非连续 + + assert not out.is_contiguous(), "out should be non-contiguous before calling logsumexp" + + # 使用 ntops 的 logsumexp + result = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim, out=out) + + # 验证结果 + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-6 + assert torch.allclose(result, reference_output, rtol=rtol, atol=atol) + assert result is out + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_logsumexp_with_out_multiple_strides(dtype): + """测试 logsumexp 使用多个维度都有非标准 strides 的输出张量""" + device = 
"cuda" + + # 创建输入张量 + input_tensor = torch.randn(2, 6, 8, dtype=dtype, device=device) + dim = 2 + keepdim = True + + # 计算参考输出 + reference_output = torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 创建一个多个维度都有非标准 strides 的输出张量 + # 使用步进切片在多个维度上创建非连续张量 + large_tensor = torch.empty(4, 12, 3, dtype=dtype, device=device) + out = large_tensor[::2, ::2, :1] # shape (2, 6, 1),所有维度的 strides 都非标准 + + # 使用 ntops 的 logsumexp + result = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim, out=out) + + # 验证结果 + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-6 + assert torch.allclose(result, reference_output, rtol=rtol, atol=atol) + assert result is out + assert not out.is_contiguous() + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("keepdim", [False, True]) +def test_logsumexp_out_vs_no_out(dtype, keepdim): + """测试使用 out 参数和不使用 out 参数的结果一致性""" + device = "cuda" + + # 创建输入张量 + input_tensor = torch.randn(4, 5, 6, dtype=dtype, device=device) + dim = 1 + + # 不使用 out 参数 + result_no_out = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim) + + # 使用连续的 out 参数 + out_contiguous = torch.empty_like(result_no_out) + result_with_out = ntops.torch.logsumexp(input_tensor, dim=dim, keepdim=keepdim, out=out_contiguous) + + # 验证结果一致 + assert torch.allclose(result_no_out, result_with_out, rtol=1e-6, atol=1e-6) + assert result_with_out is out_contiguous + + diff --git a/tests/test_lp_pool1d.py b/tests/test_lp_pool1d.py new file mode 100644 index 0000000..f76207f --- /dev/null +++ b/tests/test_lp_pool1d.py @@ -0,0 +1,46 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("norm_type", (1.0, 2.0, 3.0)) +@pytest.mark.parametrize("ceil_mode", (False, True)) +@pytest.mark.parametrize("use_stride", (False, True)) +def test_lp_pool1d(norm_type, ceil_mode, use_stride): + device = "cuda" + dtype = torch.float32 + + batch = random.randint(1, 4) + channels = random.randint(1, 4) + length = random.randint(4, 32) + + kernel_size = random.randint(1, min(5, length)) + + if use_stride: + stride = random.randint(1, max(1, kernel_size)) + else: + stride = None + + input_tensor = torch.randn((batch, channels, length), device=device, dtype=dtype) + + ntops_output = ntops.torch.lp_pool1d( + input_tensor, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + reference_output = torch.nn.functional.lp_pool1d( + input_tensor, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + + assert torch.allclose(ntops_output, reference_output, atol=1e-3, rtol=1e-3, equal_nan=True) diff --git a/tests/test_lp_pool2d.py b/tests/test_lp_pool2d.py new file mode 100644 index 0000000..d341d53 --- /dev/null +++ b/tests/test_lp_pool2d.py @@ -0,0 +1,59 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("norm_type", (1.0, 2.0, 3.0)) +@pytest.mark.parametrize("ceil_mode", (False, True)) +@pytest.mark.parametrize("use_tuple_kernel", (False, True)) +@pytest.mark.parametrize("use_stride", (False, True)) +def test_lp_pool2d(norm_type, ceil_mode, use_tuple_kernel, use_stride): + device = "cuda" + dtype = torch.float32 + + batch = random.randint(1, 3) + channels = random.randint(1, 4) + height = random.randint(4, 24) + width = 
random.randint(4, 24) + + if use_tuple_kernel: + k_h = random.randint(1, min(5, height)) + k_w = random.randint(1, min(5, width)) + kernel_size = (k_h, k_w) + else: + k = random.randint(1, min(5, height, width)) + kernel_size = k + + if use_stride: + if use_tuple_kernel and isinstance(kernel_size, tuple): + s_h = random.randint(1, max(1, kernel_size[0])) + s_w = random.randint(1, max(1, kernel_size[1])) + stride = (s_h, s_w) + else: + stride = random.randint(1, max(1, kernel_size if isinstance(kernel_size, int) else min(kernel_size))) + else: + stride = None + + input_tensor = torch.randn((batch, channels, height, width), device=device, dtype=dtype) + + ntops_output = ntops.torch.lp_pool2d( + input_tensor, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + reference_output = torch.nn.functional.lp_pool2d( + input_tensor, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + + assert torch.allclose(ntops_output, reference_output, atol=1e-3, rtol=1e-3, equal_nan=True) diff --git a/tests/test_lp_pool3d.py b/tests/test_lp_pool3d.py new file mode 100644 index 0000000..3862177 --- /dev/null +++ b/tests/test_lp_pool3d.py @@ -0,0 +1,63 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("norm_type", (1.0, 2.0, 3.0)) +@pytest.mark.parametrize("ceil_mode", (False, True)) +@pytest.mark.parametrize("use_tuple_kernel", (False, True)) +@pytest.mark.parametrize("use_stride", (False, True)) +def test_lp_pool3d(norm_type, ceil_mode, use_tuple_kernel, use_stride): + device = "cuda" + dtype = torch.float32 + + batch = random.randint(1, 2) + channels = random.randint(1, 3) + depth = random.randint(4, 16) + height = random.randint(4, 16) + width = random.randint(4, 16) + + if use_tuple_kernel: + k_d = random.randint(1, min(4, depth)) + k_h = random.randint(1, min(4, height)) + k_w = random.randint(1, min(4, width)) + kernel_size = (k_d, k_h, k_w) + else: + k = random.randint(1, min(4, depth, height, width)) + kernel_size = k + + if use_stride: + if use_tuple_kernel and isinstance(kernel_size, tuple): + s_d = random.randint(1, max(1, kernel_size[0])) + s_h = random.randint(1, max(1, kernel_size[1])) + s_w = random.randint(1, max(1, kernel_size[2])) + stride = (s_d, s_h, s_w) + else: + base = kernel_size if isinstance(kernel_size, int) else min(kernel_size) + stride = random.randint(1, max(1, base)) + else: + stride = None + + input_tensor = torch.randn((batch, channels, depth, height, width), device=device, dtype=dtype) + + ntops_output = ntops.torch.lp_pool3d( + input_tensor, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + reference_output = torch.nn.functional.lp_pool3d( + input_tensor, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + + assert torch.allclose(ntops_output, reference_output, atol=1e-3, rtol=1e-3, equal_nan=True) diff --git a/tests/test_max.py b/tests/test_max.py new file mode 100644 index 0000000..35b6310 --- /dev/null +++ b/tests/test_max.py @@ -0,0 +1,37 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +@pytest.mark.parametrize("keepdim", (False, True)) +def test_max_dim(shape, dtype, device, rtol, atol, keepdim): + input_tensor = torch.randn(shape, dtype=dtype, device=device) + dim 
= random.randint(0, input_tensor.ndim - 1)
+
+    if random.choice((True, False)):
+        dim = dim - input_tensor.ndim
+
+    ntops_values, ntops_indices = ntops.torch.max(input_tensor, dim=dim, keepdim=keepdim)
+    reference_values, reference_indices = torch.max(input_tensor, dim=dim, keepdim=keepdim)
+
+    assert torch.allclose(ntops_values, reference_values, rtol=rtol, atol=atol)
+    assert torch.equal(ntops_indices, reference_indices)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_max_global(shape, dtype, device, rtol, atol):
+    """Test global max (dim=None), which returns the scalar maximum."""
+    input_tensor = torch.randn(shape, dtype=dtype, device=device)
+
+    ntops_value = ntops.torch.max(input_tensor)
+    reference_value = torch.max(input_tensor)
+
+    assert torch.allclose(ntops_value, reference_value, rtol=rtol, atol=atol)
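
Reviewer note (not part of the patch): the comments in the application_ceil_mode kernels above claim that PyTorch's lp_pool computes (mean(window**p) * kernel_size)**(1/p) rather than the documented plain sum, which only differs on partial boundary windows when ceil_mode=True and stride != kernel_size. The snippet below is a minimal stand-alone check of that claim using plain torch.nn.functional; the 16.5 in the comment is the value expected under that interpretation (the last window [5, 6] holds only 2 of 3 elements), not a recorded result.

    import torch
    import torch.nn.functional as F

    # 1-D input 1..6; kernel_size=3, stride=2, ceil_mode=True yields three
    # windows: [1, 2, 3], [3, 4, 5], and the partial window [5, 6].
    x = torch.arange(1.0, 7.0).view(1, 1, 6)

    out = F.lp_pool1d(x, norm_type=1.0, kernel_size=3, stride=2, ceil_mode=True)
    print(out)
    # The documented plain-sum formula (p=1) would give [6, 12, 11] for the
    # three windows; with mean(window**p) * kernel_size the last entry becomes
    # (5 + 6) / 2 * 3 = 16.5, which is what the masked sum / cnt * kernel_size
    # path in the ceil_mode kernels reproduces.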