diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..848045f --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,30 @@ +# Repository Guidelines + +## 项目结构与模块组织 +- `src/ntops/` 是包入口;`kernels/` 放 NineToothed kernel 定义(每个算子一个文件),`torch/` 是 PyTorch 侧封装与调用。 +- `tests/` 使用 pytest,单算子测试如 `tests/test_add.py`,共享夹具在 `tests/conftest.py` 与 `tests/utils.py`。 +- 根目录 `README.md` 为项目概览,`doc.md`/`topk.md` 记录算子笔记与设计说明。 + +## NineToothed 依赖与接口要点 +- NineToothed 是基于 Triton 的 DSL,核心范式为 TOM:`arrangement`(布局变换)+ `application`(算子表达)+ `tensors`(符号张量定义)。 +- `ninetoothed.Tensor` 表达符号张量,`tile/expand/permute` 等用于构建层级布局;`ninetoothed.block_size()` 提供可调块大小符号。 +- `ninetoothed.language as ntl` 提供算子接口(如 `ntl.zeros`、`ntl.dot`);`ninetoothed.make(...)` 负责 JIT 生成可调用 kernel。 +- 在 ntops 中,每个 kernel 提供 `premake(...) -> (arrangement, application, tensors)`;`src/ntops/torch/utils.py` 中 `_cached_make` 统一调用 `ninetoothed.make`。 + +## 构建、测试与开发命令 +- `python -m pip install -e .` 以 editable 方式安装;`python -m pip install -e ".[testing]"` 安装测试依赖。 +- `python -m pytest` 运行全量测试;`python -m pytest tests/test_mm.py` 运行单文件。 +- `python -m ruff check .` 运行 lint(ruff)。 + +## 编码风格与命名约定 +- Python 4 空格缩进,snake_case;算子文件与入口函数保持同名(如 `kernels/gelu.py`)。 +- `application` 中仅写符号计算,避免引入真实张量运算或框架 side-effect。 +- `pyproject.toml` 已启用 ruff 的错误/导入规则,保持导入顺序一致。 + +## 测试指南 +- 使用 pytest;新增算子测试命名 `test_.py`,优先使用 `pytest.mark.parametrize`。 +- 精度容差与已有测试一致(通常为 `rtol, atol`),必要时在测试内说明原因。 + +## 提交与 PR 规范 +- 提交信息简洁、动词开头(示例:"Add matmul operator"、"Refactor tests")。 +- PR 需描述改动、附测试结果,并关联相关 issue 或设计文档(如 `doc.md`)。 diff --git a/doc.md b/doc.md new file mode 100644 index 0000000..640af65 --- /dev/null +++ b/doc.md @@ -0,0 +1,135 @@ +目前,九齿(NineToothed)是一门基于 Triton 的领域特定语言(DSL),旨在进一步简化高性能计算内核的开发。它通过引入面向张量的元编程(tensor-oriented metaprogramming),抽象掉了指针算术运算和内存访问等底层细节,能够降低并行编程的门槛。九齿能够让开发者使用少量简洁的代码实现较高性能的计算内核,并且可以提高代码的可读性和可维护性。 +核心概念 +符号 +符号这一概念,与这篇 SymPy 教程当中描写的类似。符号并不存储实际的数值,只存储符号或是符号表达式,所以允许进行一些符号化的数学运算。在九齿中,我们可以使用 Symbol 来创建一个符号。例如,在下面的代码里,我们先是创建了名为 BLOCK_SIZE_M 和 BLOCK_SIZE_N 的两个符号,之后对它们进行了乘法操作: +>>> from ninetoothed import Symbol +>>> BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M") +>>> BLOCK_SIZE_M +BLOCK_SIZE_M +>>> BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N") +>>> BLOCK_SIZE_N +BLOCK_SIZE_N +>>> BLOCK_SIZE_M * BLOCK_SIZE_N +BLOCK_SIZE_M * BLOCK_SIZE_N +符号张量 +张量是深度学习领域的基础概念之一,如果您对张量尚不熟悉,可以参考这篇 PyTorch 的教程。九齿当中的张量,与 PyTorch 中的类似,但是并不存储实际数据,仅在 shape、strides 等成员变量中存储符号表达式。在九齿中,我们可以使用 Tensor 来创建一个张量。如下方代码所示,Tensor(2) 表示构造一个二维张量,也就是一个矩阵,而它的 shape 成员里所存储的,也都是符号,并非具体的数值: +>>> from ninetoothed import Tensor +>>> x = Tensor(2) +>>> x.shape +(ninetoothed_tensor_0_size_0, ninetoothed_tensor_0_size_1) +面向张量的元编程 +得益于符号张量,我们可以对九齿中的张量进行一些编译期操作,这样的操作被称为元操作,如 tile、expand、squeeze、permute 等。例如,在这一段代码中,我们对 x 进行了 tile 操作,即将 x 分为形状为 (BLOCK_SIZE_M, BLOCK_SIZE_N) 的块: +>>> x_tiled = x.tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) +>>> x_tiled.shape +((ninetoothed_tensor_0_size_0 - (BLOCK_SIZE_M - 1) - 1 + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + 1, (ninetoothed_tensor_0_size_1 - (BLOCK_SIZE_N - 1) - 1 + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + 1) +>>> x_tiled.dtype.shape +(BLOCK_SIZE_M, BLOCK_SIZE_N) +我们注意到,x_tiled 的 dtype 也有 shape 这一成员变量。这是由于,九齿当中的张量是可以嵌套的,即一个张量的元素也可以是一个张量。也就是说,在 tile 的过程中,我们创建了一个双层的张量,其中外层张量的每一个元素,都是一个内层张量。为了方便理解,我们可以使用如下的数值化示例来进行说明: +>>> BLOCK_SIZE_M = 2 +>>> BLOCK_SIZE_N = 2 +>>> x = Tensor(shape=(4, 8)) +>>> x_tiled = x.tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) +>>> x_tiled.shape +(2, 4) +>>> x_tiled.dtype.shape +(2, 2) +就像下图所示的那样,我们将一个形状为 (4, 8) 的张量 x 分成了形状为 (2, 2) 的块(内层张量),总共分成了 (2, 4) 个这样的张量(外层张量): +[图片] +排布与应用范式 +九齿引入了排布与应用(arrange-and-apply)范式,其中排布指的是如何使用元操作,对张量进行分块等排布,使得各参数张量的分块能够对齐;应用则指的是如何应用排布后的分块来完成整个算法。或者说,排布后所产生的多层张量的最外层,将会被用于并行程序的启动,而每一个并行程序所实际使用的,其实是内层张量。 +接下来,让我们先通过一个简单的向量加法的例子,来理解这一范式: +import ninetoothed +from ninetoothed import Symbol, Tensor + + +def arrangement(lhs, rhs, output): + BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True) + + return lhs.tile((BLOCK_SIZE,)), rhs.tile((BLOCK_SIZE,)), output.tile((BLOCK_SIZE,)) + + +def application(lhs, rhs, output): + output = lhs + rhs + + +tensors = (Tensor(1), Tensor(1), Tensor(1)) +add_kernel = ninetoothed.make(arrangement, application, tensors) +在以上代码中,我们先是定义了三个张量的排布,即将三个参数张量全都分成形状为 (BLOCK_SIZE,) 的块,之后又定义了排布后张量的应用,即将 lhs 与 rhs 每一组对应的分块相加,之后写入到 output 所对应的分块当中。最后我们使用 ninetoothed.make 来将“张量”、“排布”、“应用”三个算法的组成部分进行整合,生成出可以运行的 add_kernel。需要注意的是:application 当中的 lhs、rhs、output 都是参数张量的每一组分块,而并非张量本身,如果使用上面提到的数值化示例,那 application 的参数,应当为形状为 (2, 2) 的分块,而不是形状为 (4, 8) 的原张量。 +有了 add_kernel,我们就可以直接使用以下方式进行调用: +import torch + +dtype = torch.float16 +device = "cuda" + +lhs = torch.tensor((1, 2, 3), dtype=dtype, device=device) +rhs = torch.tensor((4, 5, 6), dtype=dtype, device=device) +output = torch.empty_like(lhs) +add_kernel(lhs, rhs, output) +reference = torch.tensor((5, 7, 9), dtype=dtype, device=device) +assert torch.allclose(output, reference) +可以看到,我们在调用 add_kernel 时,并没有提供 BLOCK_SIZE 的实际取值。这是因为,在构造 BLOCK_SIZE 时,我们在 Symbol 中使用了 meta=True,这代表我们希望使用九齿所提供的配置组合来进行自动调优。如果我们希望人为提供取值(比如我们在进行调试时),我们可以使用 constexpr=True 来替代 meta=True,这样我们就可以使用以下方式传递具体的取值: +add_kernel(lhs, rhs, output, BLOCK_SIZE=1024) +索引和迭代 +九齿当中的张量并不局限于双层,也可以是三层甚至更多层,但是只有排布后张量的最外层会被用于启动并行程序。换句话说,三及以上层数的张量,在应用函数里,也是层级张量,是可以被索引和迭代的。 +让我们来通过一个稍微复杂一点的矩阵乘法的例子,来理解一下张量的索引和迭代,并进一步体会一下排布与应用范式: +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + + +def arrangement(lhs, rhs, output): + BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) + BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) + BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) + + output_arranged = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) + + lhs_arranged = lhs.tile((BLOCK_SIZE_M, BLOCK_SIZE_K)) + lhs_arranged = lhs_arranged.tile((1, -1)) + lhs_arranged = lhs_arranged.expand((-1, output_arranged.shape[1])) + lhs_arranged.dtype = lhs_arranged.dtype.squeeze(0) + + rhs_arranged = rhs.tile((BLOCK_SIZE_K, BLOCK_SIZE_N)) + rhs_arranged = rhs_arranged.tile((-1, 1)) + rhs_arranged = rhs_arranged.expand((output_arranged.shape[0], -1)) + rhs_arranged.dtype = rhs_arranged.dtype.squeeze(1) + + return lhs_arranged, rhs_arranged, output_arranged + + +def application(lhs, rhs, output): + accumulator = ntl.zeros(output.shape, dtype=ntl.float32) + for k in range(lhs.shape[0]): + accumulator += ntl.dot(lhs[k], rhs[k]) + output = accumulator.to(ntl.float16) + + +tensors = (Tensor(2), Tensor(2), Tensor(2)) +matmul_kernel = ninetoothed.make(arrangement, application, tensors) +可以看出,矩阵乘法的张量排布,要比向量加法复杂不少。为了辅助理解,以下是一张该分块算法的图示: +[图片] +在代码中,我们首先定义了 BLOCK_SIZE_M、BLOCK_SIZE_N、BLOCK_SIZE_K 三个符号,用于表示分块的形状。具体来讲,我们先将 output 矩阵 tile 成形状为 (BLOCK_SIZE_M, BLOCK_SIZE_N) 的块,将 lhs 矩阵 tile 成形状为 (BLOCK_SIZE_M, BLOCK_SIZE_K) 的块,并将 rhs 矩阵 tile 成形状为 (BLOCK_SIZE_K, BLOCK_SIZE_N) 的块: +[图片] +[图片] +[图片] +我们注意到,只进行分块对于矩阵乘法是不足的。按照上面的算法图示,output 当中的每一个分块,对应的是 lhs 的一行分块,与 rhs 的一列分块,所以我们还需要对 lhs 和 rhs 进行进一步的 tile,也就是将 lhs 的每一行 tile 在一起,和将 rhs 的每一列 tile 在一起: +[图片] +[图片] +但是这还并不是全部。还记得在进行张量排布时,我们最终需要做到什么嘛?没错,是使得各参数张量的分块能够对齐。再结合九齿的工作原理,排布后张量的最外层将被用于启动并行程序,我们可以引申出一条重要的结论:各参数张量排布后的最外层应当具有相同的形状。很明显,目前我们的三个张量的最外层,形状并不相同,这往往说明我们的排布并不正确,或者尚未完成。通过图示我们可以知道,我们需要将 lhs 的每一行分块,与 rhs 的每一列分块对齐,这一点我们可以通过广播来做到,也就是将 lhs 沿着横向 expand,将 rhs 沿着竖向 expand,均 expand 至与 output 有同样的形状: +[图片] +[图片] +至此,三个参数张量排布后的最外层,便具有了相同的形状。实际上,排布阶段可以在此停止,我们已经可以据此写出 application 函数,但是我们发现,刚才所分成的 lhs 的行分块和 rhs 的列分块是二维的,并且具有 (1, ...) 和 (..., 1) 这样形式的形状。也就是说,如果不进行其它操作,那么我们访问行分块和列分块的方式就得是 lhs[0, k] 和 rhs[k, 0],如果我们想要依靠 lhs 找到 k 的范围,那就需要通过 lhs.shape[1]。但是我们知道,大小为 1 的维度,在这种情况下完全可以被去掉,这就是为什么我们在最后加入了 squeeze 操作。这样,我们在访问行分块和列分块时就可以使用 lhs[k] 和 rhs[k],寻找 k 的范围时也可以使用 lhs.shape[0] 了。 +现在让我们来看 application 函数。在函数体当中,我们先定义了一个 accumulator,用于累加中间结果,之后就迭代了对应好的 lhs 的行块和 rhs 的列块,并且把他们相乘的结果累加到了 accumulator 当中,最后再将 accumulator 放到了对应的 output 的分块当中。由于参数张量被分成的每一块都被执行了这样的操作,因此对于整体而言,矩阵乘法就完成了。 +与向量加法相同,在定义好 arrangement 和 application 后,我们可以使用 ninetoothed.make 对它们进行整合,从而形成一个可以运行的 matmul_kernel。我们可以使用以下方式对其进行调用: +import torch + +dtype = torch.float16 +device = "cuda" + +lhs = torch.tensor(((1, 2), (3, 4)), dtype=dtype, device=device) +rhs = torch.tensor(((5, 6), (7, 8)), dtype=dtype, device=device) +output = torch.empty((lhs.shape[0], rhs.shape[1]), dtype=dtype, device=device) +matmul_kernel(lhs, rhs, output) +reference = torch.tensor(((19, 22), (43, 50)), dtype=dtype, device=device) +assert torch.allclose(output, reference) +这些就是九齿当中最核心的几个概念。 diff --git a/ninetoothed_puzzles.py b/ninetoothed_puzzles.py new file mode 100644 index 0000000..efaee7a --- /dev/null +++ b/ninetoothed_puzzles.py @@ -0,0 +1,411 @@ +# %% [markdown] +# # NineToothed Puzzles +# +# 九齿是一门张量级的深度学习领域特定语言(DSL),主要用途为开发计算内核。与 Triton 等传统并行编程语言相比,其通过引入面向张量的元编程,抽象掉了指针算术运算等底层细节,在保持性能与 Triton 相当的同时,提高了代码的可读性与可维护性。 +# +# 该笔记将带你由浅入深地学习和掌握九齿的使用方法。 + +# %% [markdown] +# ## 依赖 +# +# 首先,请确保 `ninetoothed` 等依赖已经安装好了。一般可以使用以下代码块进行安装,但是由于大家的环境可能各不相同,包管理工具也可能各不相同,所以以下代码块仅供参考,实际使用时请按需修改为合适的版本。 + +# %% +# %pip install ninetoothed +# %pip install ninetoothed[debugging] +# %pip install ninetoothed[visualization] + +# %% [markdown] +# 如果可以成功运行以下代码块,那就说明我们的环境已经准备好了。 + +# %% +import ninetoothed +import ninetoothed.language as ntl +import numpy as np +import torch +from ninetoothed import Symbol, Tensor +from ninetoothed.visualization import visualize + +assert torch.cuda.is_available(), "CUDA is not available." + +# %% [markdown] +# ## Quickstart +# +# 既然说九齿是一门张量级的 DSL,那么就让我们先来创建一个[张量](https://ninetoothed.org/python_api/tensor.html)看看。注:以下内容默认大家对张量有一个基本的了解,如果没有的话可以先参考[这篇 PyTorch 的教程](https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html)。 + +# %% +x = Tensor(shape=(2, 3)) + +# %% [markdown] +# 恭喜,我们刚刚成功创建了一个形状为 `(2, 3)` 的张量 `x`。让我们先试着调用它的一个方法:[`Tensor.eval`](https://ninetoothed.org/python_api/generated/ninetoothed.Tensor.eval.html#ninetoothed.Tensor.eval)。 + +# %% +x.eval() + +# %% [markdown] +# 在九齿当中,我们可以使用 `Tensor.eval` 对一个张量进行求值,从而可以将其打印出来。你或许会有疑问:为什么不能直接打印 `x`,还要调用 `Tensor.eval`?为了解答这个问题,让我们先来看看下面这个代码块。 + +# %% +x = Tensor(2) + +# %% [markdown] +# 我们会发现上面这行代码也可以顺利地完成。可是 `2` 在这里是什么意思呢?答案是维度数。以上代码也可以写成 `x = Tensor(ndim=2)`,换句话说,就是创建了一个矩阵(二维张量)`x`。如果大家对 PyTorch 或者 NumPy 很熟悉,就一定会好奇,单纯传递了维度数,怎么就能够创建出一个张量呢?正常情况下,不都得起码传递一个形状嘛?这是因为:九齿当中的张量是符号张量,其不存储实际数据,仅在 `Tensor.shape` 等成员变量中存储符号表达式。 + +# %% +x.shape + +# %% [markdown] +# 从以上代码块的输出我们可以看到 `x.shape` 的每个元素都是九齿生成的符号,这也印证了刚才所说的“九齿当中的张量是符号张量”。正因为如此,所以我们不能直接打印 `x`,而是需要先调用 `Tensor.eval`,而所谓求值,其实就是希望将一个不易可视化的符号张量,转化为一个更加看得见、摸得着的数值张量(目前的默认输出是 `numpy.ndarray`),其中每个元素都表示一个索引。 + +# %% +x.eval({x: Tensor(shape=(2, 3))}) + +# %% [markdown] +# 大家肯定注意到了我们这次调用加入了参数 `Tensor(shape=(2, 3))`,这是因为在进行求值时,我们必须提供数值化所需的全部信息。由于 `x` 的形状是未知的,所以我们必须要手动传值,确保每个符号都有对应的数值代入,才可以完成这一过程。 + +# %% +subs = {x: Tensor(shape=(8, 8))} +x.eval(subs) + +# %% [markdown] +# 根据代入的数值不同,`Tensor.eval` 的结果自然也可以是不同的。 + +# %% +x_substituted = x.subs(subs) +x_substituted, x_substituted.shape + +# %% [markdown] +# 九齿当中还提供 [`Tensor.subs`](https://ninetoothed.org/python_api/generated/ninetoothed.Tensor.subs.html#ninetoothed.Tensor.subs) 函数,可以将一个九齿张量数值化,注意 `Tensor.subs` 的输出与 `Tensor.eval` 不同,`Tensor.eval` 的输出是一个非嵌套数值张量,里面存储着索引数据,而 `Tensor.subs` 的输出仍然是九齿张量,只不过其中如 `shape` 等属性中的符号会被替换为数值。 + +# %% +x_substituted.eval() + +# %% [markdown] +# 单独提供 `Tensor.subs` 的好处就是可以方便使用九齿当中一些专门给九齿数值化张量使用的工具,如 `visualize` 等,这些在后文当中也会逐步介绍。由于 `x_substituted` 已经是数值化了的,所以后面的 `Tensor.eval` 中就不需要再传入 `subs` 了。 + +# %% [markdown] +# 很好,我们现在掌握了可以将九齿当中本来相对抽象的符号张量转化为比较具象的数值张量的方法,接下来就可以尝试对张量进行一些操作了。就让我们从 [`Tensor.tile`](https://ninetoothed.org/python_api/generated/ninetoothed.Tensor.tile.html#ninetoothed.Tensor.tile) 开始吧,这也是九齿当中最为核心的操作。 + +# %% +x_tiled = x.tile((4, 4)) + +# %% [markdown] +# 上面这行代码的意思是:将 `x` 分成每一块大小为 `(4, 4)` 的块。如果按照上面 `x` 的形状为 `(8, 8)` 的话,按理说就可以分成 `(2, 2)` 块。 + +# %% +x_tiled_substituted = x_tiled.subs(subs) +x_tiled_substituted.eval() + +# %% [markdown] +# 可以看到求值后的张量的维度数从原先的 `2` 变成了 `4`,这是因为 `Tensor.tile` 操作会生成嵌套张量。用上面的例子来说,结果就是一个双层的张量,其中外层的形状为 `(2, 2)`,内层的形状为 `(4, 4)`,所以求值后的形状就是 `(2, 2, 4, 4)`。为了能够更清晰地理解嵌套张量,我们可以使用 [`visualize`](https://ninetoothed.org/python_api/visualization.html#ninetoothed.visualization.visualize) 来可视化一个张量。 + +# %% +visualize(x_tiled_substituted) + +# %% [markdown] +# 如图所示,该张量是一个双层张量,其中外层的形状为 `(2, 2)`,内层的形状为 `(4, 4)`。也可以说,外层的每个元素,都是一个 `(4, 4)` 的内层张量。 + +# %% +x_tiled.dtype + +# %% [markdown] +# 我们可以通过 `Tensor.dtype` 来访问内层张量,因为嵌套张量的元素类型已经不仅仅局限于 `float`、`int` 等,也可以是 `Tensor`。这样的好处是我们可以方便地对其中某一层进行操作,比如对 `x_tiled` 的内层 `x_tiled.dtype` 进行 [`flatten`](https://ninetoothed.org/python_api/generated/ninetoothed.Tensor.flatten.html#ninetoothed.Tensor.flatten) 操作。 + +# %% +x_tiled.dtype = x_tiled.dtype.flatten() +x_tiled_substituted = x_tiled.subs(subs) +visualize(x_tiled_substituted) +x_tiled_substituted.eval() + +# %% [markdown] +# 如果我们集中注意力,会发现上述代码中调用 `Tensor.tile` 时传入的是 `(4, 4)`,也就是确定的数值,这看起来没什么问题,但是在九齿面向张量的元编程中,却的确有些格格不入。实际上,我们可以定义[符号](https://ninetoothed.org/python_api/symbol.html)来进行操作,而非必须使用具体的数值。 + +# %% +block_size_m = Symbol("block_size_m", constexpr=True) +block_size_n = Symbol("block_size_n", constexpr=True) + +x_tiled = x.tile((block_size_m, block_size_n)) + +subs |= {block_size_m: 4, block_size_n: 4} + +x_tiled_substituted = x_tiled.subs(subs) +visualize(x_tiled_substituted) +x_tiled_substituted.eval() + +# %% [markdown] +# 在九齿中,我们可以通过 [`Symbol`](https://ninetoothed.org/python_api/symbol.html#symbol) 来定义符号(先不用管 `constexpr=True`,后面会讲),从而使用符号来参与张量操作。但是需要注意,由于我们引入了符号,所以在 `subs` 中我们就需要再加上 `block_size_m` 和 `block_size_n` 的取值,才能保证代入的完整性。 + +# %% [markdown] +# 很好,我们现在学会了如何对张量进行操作。我们把一系列这样的操作,称之为排布。既然可以对一个张量进行排布,那当然也可以对多个张量进行排布。 + +# %% +block_size = Symbol("block_size", constexpr=True) + + +def arrangement(x, y, z, block_size=block_size): + return x.tile((block_size,)), y.tile((block_size,)), z.tile((block_size,)) + + +# %% [markdown] +# 在以上的函数中,我们便分别对三个参数张量 `x`、`y`、`z` 进行了排布。 + +# %% +x = Tensor(1) +y = Tensor(1) +z = Tensor(1) + +x_arranged, y_arranged, z_arranged = arrangement(x, y, z) + +shape = (8,) +subs = { + x: Tensor(shape=shape), + y: Tensor(shape=shape), + z: Tensor(shape=shape), + block_size: 4, +} + +x_arranged_substituted = x_arranged.subs(subs) +y_arranged_substituted = y_arranged.subs(subs) +z_arranged_substituted = z_arranged.subs(subs) + +print(x_arranged_substituted.shape, x_arranged_substituted.dtype.shape) +print(y_arranged_substituted.shape, y_arranged_substituted.dtype.shape) +print(z_arranged_substituted.shape, z_arranged_substituted.dtype.shape) + +visualize(x_arranged_substituted) +visualize(y_arranged_substituted) +visualize(z_arranged_substituted) + +x_arranged_evaluated = x_arranged_substituted.eval() +y_arranged_evaluated = y_arranged_substituted.eval() +z_arranged_evaluated = z_arranged_substituted.eval() + +print(x_arranged_evaluated) +print(y_arranged_evaluated) +print(z_arranged_evaluated) + +# %% [markdown] +# 不难看出,如果输入张量的形状都为 `(8,)`,`block_size` 为 `4`,则输出张量均为双层张量,其中外层形状为 `(2,)`,内层形状为 `(4,)`。通过观察发现,我们可以通过排布将一个源张量变换为一个多层的嵌套张量。同理,我们也可以通过排布将多个源张量变换为多个多层的嵌套张量。 +# +# 九齿的运行机制也由此而生:**九齿会根据各个参数张量排布后的最外层张量的形状启动程序实例,并把次外层张量映射到这些程序实例上。** +# +# 所以,如果按照上述的排布,九齿就会启动 `2` 个程序实例,并把排布后的 `x`、`y`、`z` 的最外层张量的每个元素,也就是次外层张量,与这 `2` 个程序实例一一对应。 + +# %% +print("Program instance 0:") +print("x:", x_arranged_evaluated[0]) +print("y:", y_arranged_evaluated[0]) +print("z:", z_arranged_evaluated[0]) + +print("-" * 32) + +print("Program instance 1:") +print("x:", x_arranged_evaluated[1]) +print("y:", y_arranged_evaluated[1]) +print("z:", z_arranged_evaluated[1]) + +# %% [markdown] +# 从以上输出不难看出,`x`、`y`、`z` 三个张量在程序实例 `0` 上分得的都是 `[0 1 2 3]`,在程序实例 `1` 上分得的则都是 `[4 5 6 7]`。这里我们就可以更清晰地明白 `Tensor.eval` 后张量中存储索引的用途:建立张量排布前后元素的对应关系。我们通过以上的打印方式,可以看到每个程序实例上,各个参数张量的次外层张量的对应关系。我们也可以按照以下方式打印,这样就可以清晰地看出一个张量的元素在各个程序实例上的分布。总而言之,熟练地使用 `Tensor.eval` 和 `Tensor.subs` 等工具,非常有助于对九齿的理解和使用。 + +# %% +print("x:", x.eval(subs)) +print("x at program instance 0:", x_arranged_evaluated[0]) +print("x at program instance 1:", x_arranged_evaluated[1]) + + +# %% [markdown] +# 这就是九齿编程模型的并行原理。但是光有排布还不够,因为虽然我们已经将参数张量分到了各个程序实例上,但是每一个程序实例要做什么,我们还没有定义。在九齿中,我们可以通过定义应用函数来告诉九齿每个程序实例需要做什么。 + + +# %% +def application(x, y, z): + z = x + y # noqa: F841 + + +# %% [markdown] +# 上面的应用函数代码逻辑很简单,就是把 `x` 和 `y` 相加,并把结果放入 `z` 中。但是需要注意的是:应用函数的参数,是参数张量排布后的最外层张量的元素,也就是次外层张量,而不是张量本身。也就是说,如果套用上面的假设,这里的 `x`、`y`、`z` 都是指长度为 `4` 的块,而不是长度为 `8` 的原本的张量。 +# +# 很好,我们现在有了一个排布函数 `arrangement` 和一个应用函数 `application`,接下来就可以将它们整合,从而形成一个完整可运行的计算内核。 + +# %% +kernel = ninetoothed.make(arrangement, application, (Tensor(1), Tensor(1), Tensor(1))) + +# %% [markdown] +# 这段代码的意思就是说,我想要按照 `arrangement` 函数对三个一维张量,也就是向量,进行排布,并按照 `application` 函数应用排布后的张量,最终形成一个计算内核 `kernel`。我们把这样构造计算内核的范式,称之为排布与应用范式。 +# +# 我们可以如下所示对 `kernel` 进行调用: + +# %% +size = 240620 +device = "cuda" + +x = torch.randn(size, device=device) +y = torch.randn(size, device=device) +z = torch.empty_like(x) + +kernel(x, y, z, block_size=64) + +print(x) +print(y) +print(z) + +reference = x + y + +print(reference) + +assert torch.allclose(z, reference) + +# %% [markdown] +# 我们不难发现,上面实现出的 `kernel`,其实就是一个向量加法的计算内核。所以说,使用九齿实现向量加法,实际只需要以下几行即可。 + +# %% +block_size = Symbol("block_size", constexpr=True) + + +def arrangement(x, y, z, block_size=block_size): + return x.tile((block_size,)), y.tile((block_size,)), z.tile((block_size,)) + + +def application(x, y, z): + z = x + y # noqa: F841 + + +kernel = ninetoothed.make(arrangement, application, (Tensor(1), Tensor(1), Tensor(1))) + +# %% [markdown] +# 现在让我们来看一下 `Symbol` 中的 `constexpr=True`。对 C++ 熟悉的小伙伴应该对 `constexpr` 不陌生,其在 C++ 中表示编译时常量,这也是 `constexpr` 在九齿中的含义。换句话说,`Symbol("block_size", constexpr=True)` 表示我们希望创建一个编译时确定取值的符号。之所以如此,是因为九齿当前继承了 Triton 的一个约束:最内层张量的形状需要在编译时是确定的。这就是为什么之前调用 `kernel` 时我们传递了 `block_size=64`:因为 JIT 编译时需要知道 `block_size`。我们还可以交由九齿来选择具体的取值,就像 Triton 当中有 `triton.autotune` 一样,九齿也提供自动调优功能。 + +# %% +block_size = Symbol( + "block_size", meta=True, lower_bound=32, upper_bound=128, power_of_two=True +) + +# %% [markdown] +# `meta=True` 的意思就是,我希望创建一个元符号,即将该符号的具体取值交由九齿决定,当然我们也需要提供一些信息,比如这个符号的取值范围,以及它是否为 2 的幂之类的。 + +# %% [markdown] +# 由于 block size 的创建太高频,几乎是所有计算内核必要的,所以九齿提供一个很好用的函数 `ninetoothed.block_size`,专门用来定义 block size,使用它时九齿将自动选择合适的配置进行自动调优。 + +# %% +block_size = ninetoothed.block_size() + + +# %% [markdown] +# 让我们尝试使用它来定义和运行计算内核,这次调用 `kernel` 时便不再需要传递 `block_size` 了,因为我们已经告诉了九齿,我们希望通过自动调优来找到合适的 `block_size` 取值。 +# +# 注:由于自动调优需要时间,所以在后面的部分我们会统一使用 `constexpr` 符号。事实上,在刚开始开发和调试某一计算内核时,也建议先使用 `constexpr`,一方面加快原型验证的速度,一方面有确定的取值也有助于调试。可以在调试完成后需要性能时再打开自动调优,就好比 Debug 和 Release 模式一样。 + + +# %% +def arrangement(x, y, z, block_size=block_size): + return x.tile((block_size,)), y.tile((block_size,)), z.tile((block_size,)) + + +def application(x, y, z): + z = x + y # noqa: F841 + + +kernel = ninetoothed.make(arrangement, application, (Tensor(1), Tensor(1), Tensor(1))) + +size = 240620 +device = "cuda" + +x = torch.randn(size, device=device) +y = torch.randn(size, device=device) +z = torch.empty_like(x) + +kernel(x, y, z) + +print(x) +print(y) +print(z) + +reference = x + y + +print(reference) + +assert torch.allclose(z, reference) + +# %% [markdown] +# 在向量加法中,参数张量经过排布变成了双层的张量,但是九齿当中的张量并不局限于双层,也可以是三层甚至更多层。 + +# %% +x = Tensor(2) + +x_arranged = x.tile((1, block_size)) +x_arranged = x_arranged.tile((1, -1)) + +subs = {x: Tensor(shape=(4, 8)), block_size: 4} + +x_arranged_substituted = x_arranged.subs(subs) +visualize(x_arranged_substituted) +x_arranged_evaluated = x_arranged_substituted.eval() +x_arranged_evaluated + +# %% [markdown] +# 可以看出,以上代码构造出了一个三层的张量。具体而言,上述代码先是将 `x` `tile` 成了形状为 `(1, block_size)` 的若干块,也就是每一行若干块,之后又将每一行分块 `tile` 在了一起(跟很多 PyTorch 函数一样,`-1` 在 `tile` 中表示维度原本的大小)。这里需要注意一点:只有排布后的最外层会被用于启动程序实例。换句话说,三及以上层的张量,在应用函数里,也是层级张量,是可以被索引和迭代的。 + +# %% +for pid, index in enumerate(np.ndindex(x_arranged_substituted.shape)): + print(f"x at program instance {pid}:") + for idx in np.ndindex(x_arranged_substituted.dtype.shape): + print(x_arranged_evaluated[index][idx]) + +# %% [markdown] +# 从上面的输出我们可以看出,每个程序实例中都得进一步迭代,才能得到最内层形状为 `(1, 4)` 的张量。为了进一步理解,让我们基于上述排布,完整地实现一个计算内核看看。 + +# %% +block_size = Symbol("block_size", constexpr=True) + + +def arrangement(x, y, block_size=block_size): + x_arranged = x.tile((1, block_size)) + x_arranged = x_arranged.tile((1, -1)) + + y_arranged = y.tile((1, 1)) + + return x_arranged, y_arranged + + +def application(x, y): + acc = ntl.zeros(y.shape, dtype=y.dtype) + + for i in range(x.shape[1]): + acc += ntl.sum(x[0, i], axis=-1) + + y = acc + + +kernel = ninetoothed.make(arrangement, application, (Tensor(2, other=0), Tensor(2))) + +# %% [markdown] +# 用自然语言来描述的话,上述计算内核做的事情就是把 `x` 的每一行分块,并且对每个分块求和,再将求和结果累加在一起,然后存入 `y` 对应的分块中。换句话说,就是把 `x` 每行的和存入 `y` 中。大家可能会想,那不是 `tile` 一次就行了,何必搞两次,还要 `for` 一下,这是因为 `block_size` 是有大小限制的,不可以无限大,所以如果 `x` 的一行特别长,就可能会超出限制,这个时候就需要把一行分为多块。事实上如果已知输入 `x` 的列数不可能过大,那么只 `tile` 一次再直接 `ntl.sum` 是完全可以的。顺带一提,这里提供的求和只是为了方便大家理解应用函数中的迭代,并不一定是很高效的实现。 +# +# 大家可能注意到了第一个 `Tensor` 中的 `other=0`,这个指的是越界情况下的取值。那什么情况下可能会越界呢?比如如果 `x` 的列数无法被 `block_size` 整除,这时就会有一个程序实例所处理的分块实际上是越过了边界的。之所以这里设置为了 `0`,是因为我们想要求和。具体情况需要具体分析。比如如果我们希望求最大值,那可能就需要设为 `float("-inf")` 了。 +# +# 有些小伙伴可能会对 `application` 中的 `ntl.zeros` 和 `ntl.sum` 感兴趣,希望知道什么函数在应用函数内可以使用,什么不可以,那么这里就可以参考 [Triton 的文档](https://triton-lang.org/main/python-api/triton.language.html),基本上 `triton.language` 中有的,都可以通过 `ninetoothed.language` 来调用,这就像 C++ 中的 `` 和 C 中的 ``,或者说 C++ 中的 `std::size_t` 和 C 中的 `size_t` 一样。当然,九齿也会添加额外的东西,比如 `Tensor.shape`、`Tensor.dtype`、`Tensor.offsets` 等。 +# +# 好,那接下来,让我们运行以下代码来验证一下刚刚定义的计算内核。 + +# %% +m = 240 +n = 620 + +x = torch.randn(m, n, device=device) +y = torch.empty((m, 1), device=device) + +kernel(x, y, block_size=64) + +print(x) +print(y) + +reference = torch.sum(x, dim=-1, keepdim=True) + +print(reference) + +assert torch.allclose(y, reference, atol=1e-5) + +# %% [markdown] +# ## 致谢 +# +# 本项目受到了 [Triton-Puzzles](https://github.com/srush/Triton-Puzzles) 的启发。 \ No newline at end of file diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 084e52c..e7a8527 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -15,6 +15,7 @@ ge, gelu, gt, + index_add, isinf, isnan, layer_norm, @@ -24,6 +25,8 @@ mul, ne, neg, + nonzero_sum_gt_last2, + one_hot, pow, relu, rms_norm, @@ -36,6 +39,8 @@ softmax, sub, tanh, + topk, + where, ) __all__ = [ @@ -55,6 +60,7 @@ "ge", "gelu", "gt", + "index_add", "isinf", "isnan", "layer_norm", @@ -64,6 +70,8 @@ "mul", "ne", "neg", + "nonzero_sum_gt_last2", + "one_hot", "pow", "relu", "rms_norm", @@ -76,4 +84,6 @@ "softmax", "sub", "tanh", + "topk", + "where", ] diff --git a/src/ntops/kernels/index_add.py b/src/ntops/kernels/index_add.py new file mode 100644 index 0000000..d1f0187 --- /dev/null +++ b/src/ntops/kernels/index_add.py @@ -0,0 +1,81 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement as reduction_arrangement + + +def arrangement(input, index, source, alpha, output, dim, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + if isinstance(dim, Tensor): + dim = dim.value + + if dim < 0: + dim += input.ndim + + index_expanded = index + for _ in range(input.ndim - 1): + index_expanded = index_expanded.unsqueeze(0) + + if dim != input.ndim - 1: + permute_order = list(range(index_expanded.ndim)) + last = permute_order.pop(-1) + permute_order.insert(dim, last) + index_expanded = index_expanded.permute(tuple(permute_order)) + + expand_shape = list(source.shape) + expand_shape[dim] = -1 + index_expanded = index_expanded.expand(tuple(expand_shape)) + + input_arranged, index_arranged, source_arranged, output_arranged = ( + reduction_arrangement( + input, index_expanded, source, output, dim=dim, block_size=block_size + ) + ) + + return input_arranged, index_arranged, source_arranged, alpha, output_arranged + + +def _application_dim0(input, index, source, alpha, output): + index_dtype = ntl.int64 + output_dtype = output.dtype.dtype + alpha_cast = ntl.cast(alpha, output_dtype) + + zero_index = ntl.cast(0, index_dtype) + zero_out = ntl.cast(0, output_dtype) + dim_size = ntl.cast(output.source.shape[0], index_dtype) + + for out_block in range(output.shape[0]): + out_vals = ntl.cast(input[out_block], output_dtype) + out_positions = ntl.cast(output[out_block].offsets(0), index_dtype) + valid_out = (out_positions >= zero_index) & (out_positions < dim_size) + + for src_block in range(source.shape[0]): + idx_block = ntl.cast(index[src_block], index_dtype) + src_vals = ntl.cast(source[src_block], output_dtype) + matches = out_positions[:, None] == idx_block[None, :] + contrib = ntl.sum(ntl.where(matches, src_vals[None, :], zero_out), 1) + out_vals += alpha_cast * contrib + + output[out_block] = ntl.where(valid_out, out_vals, zero_out) + + +def premake(ndim, dim, dtype=None, block_size=None): + if dim != 0: + raise ValueError("Only dim=0 is supported for index_add.") + + arrangement_ = functools.partial(arrangement, dim=0, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype, other=0), + Tensor(1, dtype=ninetoothed.int64, other=-1), + Tensor(ndim, dtype=dtype, other=0), + Tensor(0, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, _application_dim0, tensors diff --git a/src/ntops/kernels/nonzero_sum_gt_last2.py b/src/ntops/kernels/nonzero_sum_gt_last2.py new file mode 100644 index 0000000..217a336 --- /dev/null +++ b/src/ntops/kernels/nonzero_sum_gt_last2.py @@ -0,0 +1,65 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(input, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + ndim = input.ndim + if ndim < 2: + raise ValueError("nonzero_sum_gt_last2 requires ndim >= 2") + + non_target_dims = tuple(range(ndim - 2)) + reduce_dims = (ndim - 2, ndim - 1) + + input_arranged = input.permute(non_target_dims + reduce_dims) + input_arranged = input_arranged.flatten(start_dim=-2) + + inner_block_shape = tuple(1 for _ in non_target_dims) + (block_size,) + outer_block_shape = tuple(1 for _ in non_target_dims) + (-1,) + + input_arranged = input_arranged.tile(inner_block_shape) + input_arranged = input_arranged.tile(outer_block_shape) + input_arranged.dtype = input_arranged.dtype.squeeze(tuple(range(len(non_target_dims)))) + input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze( + tuple(range(len(non_target_dims))) + ) + + output_arranged = output.permute(non_target_dims + reduce_dims) + output_arranged = output_arranged.flatten(start_dim=-2) + output_arranged = output_arranged.tile(tuple(1 for _ in non_target_dims) + (1,)) + output_arranged.dtype = output_arranged.dtype.squeeze(tuple(range(len(non_target_dims)))) + + return input_arranged, output_arranged + + +def application(input, output): + acc = ntl.cast(0, ntl.float32) + + for i in range(input.shape[0]): + acc += ntl.sum(ntl.cast(input[i], ntl.float32)) + + output_dtype = output.dtype.dtype + is_positive = acc > ntl.cast(0, ntl.float32) + output[0] = ntl.where( + is_positive, ntl.cast(1, output_dtype), ntl.cast(0, output_dtype) + ) + + +def premake(ndim, dtype=None, block_size=None): + if ndim < 2: + raise ValueError("nonzero_sum_gt_last2 requires ndim >= 2") + + arrangement_ = functools.partial(arrangement, block_size=block_size) + + input = Tensor(ndim, dtype=dtype, other=0) + output = Tensor(ndim, dtype=dtype) + output.shape = input.shape[:-2] + (1, 1) + + tensors = (input, output) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/one_hot.py b/src/ntops/kernels/one_hot.py new file mode 100644 index 0000000..7db10ca --- /dev/null +++ b/src/ntops/kernels/one_hot.py @@ -0,0 +1,56 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def _next_power_of_2(value): + if value < 1: + raise ValueError("`value` must be positive.") + return 1 << (value - 1).bit_length() + + +def arrangement(input, output, class_block_size, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + input_flat = input.flatten() + input_flat = input_flat.unsqueeze(1) + + output_flat = output.flatten(end_dim=-1) + + input_arranged = input_flat.tile((block_size, 1)) + output_arranged = output_flat.tile((block_size, class_block_size)) + + return input_arranged, output_arranged + + +def application(input, output): + index_dtype = ntl.int64 + output_dtype = output.dtype + + input_values = ntl.cast(input, index_dtype) + class_indices = ntl.cast(output.offsets(-1), index_dtype) + + output = ntl.where( + input_values == class_indices, + ntl.cast(1, output_dtype), + ntl.cast(0, output_dtype), + ) + + +def premake(ndim, num_classes, block_size=None): + class_block_size = _next_power_of_2(num_classes) + arrangement_ = functools.partial( + arrangement, block_size=block_size, class_block_size=class_block_size + ) + + input = Tensor(ndim, dtype=ninetoothed.int64, other=-1) + output = Tensor(ndim + 1, dtype=ninetoothed.int64) + + output.shape = input.shape + (num_classes,) + + tensors = (input, output) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/topk.py b/src/ntops/kernels/topk.py new file mode 100644 index 0000000..9cd0942 --- /dev/null +++ b/src/ntops/kernels/topk.py @@ -0,0 +1,119 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement as reduction_arrangement + + +def _next_power_of_2(value): + if value < 1: + raise ValueError("`value` must be positive.") + return 1 << (value - 1).bit_length() + + +def arrangement( + input, + dim_size, + k_constexpr, + values, + indices, + dim, + block_size=None, + output_block_size=None, +): + if block_size is None: + block_size = ninetoothed.block_size() + + if dim != -1: + raise ValueError("Only dim=-1 is supported for topk.") + + dim = input.ndim - 1 + + input_arranged = reduction_arrangement(input, dim=dim, block_size=block_size)[0] + + if output_block_size is None: + output_block_size = values.shape[dim] + values_arranged = reduction_arrangement( + values, dim=dim, block_size=output_block_size + )[0] + indices_arranged = reduction_arrangement( + indices, dim=dim, block_size=output_block_size + )[0] + + return input_arranged, dim_size, k_constexpr, values_arranged, indices_arranged + + +def _application_last(input, dim_size, k_constexpr, values, indices): + value_dtype = values.dtype.dtype + index_dtype = ntl.int64 + + dim_size_ = ntl.cast(dim_size, index_dtype) + + k = k_constexpr + + neg_inf = ntl.cast(float("-inf"), value_dtype) + neg_one = ntl.cast(-1, index_dtype) + + top_vals = ntl.full(values.dtype.shape, float("-inf"), dtype=value_dtype) + top_indices = ntl.full(indices.dtype.shape, -1, dtype=index_dtype) + positions = ntl.cast(values[0].offsets(-1), index_dtype) + + for t in range(k): + best_val = neg_inf + best_idx = neg_one + + for block in range(input.shape[0]): + block_vals = input[block] + block_indices = ntl.cast(input[block].offsets(-1), index_dtype) + + valid = block_indices < dim_size_ + for prev in range(t): + prev_idx = ntl.max(ntl.where(positions == prev, top_indices, neg_one)) + valid = valid & (block_indices != prev_idx) + + masked_vals = ntl.where(valid, block_vals, neg_inf) + block_best_val = ntl.cast(ntl.max(masked_vals), value_dtype) + block_best_idx = ntl.max( + ntl.where(valid & (masked_vals == block_best_val), block_indices, neg_one) + ) + + better = block_best_val > best_val + best_val = ntl.where(better, block_best_val, best_val) + best_idx = ntl.where(better, block_best_idx, best_idx) + + write_mask = positions == t + top_vals = ntl.where(write_mask, best_val, top_vals) + top_indices = ntl.where(write_mask, best_idx, top_indices) + + values[0] = top_vals + indices[0] = top_indices + + +def premake(ndim, dim, k, dtype=None, block_size=None): + if dim != -1: + raise ValueError("Only dim=-1 is supported for topk.") + + input = Tensor(ndim, dtype=dtype, other=float("-inf")) + dim_size = Tensor(0, dtype=ninetoothed.int64) + k_constexpr = Tensor(0, dtype=ninetoothed.int64, constexpr=True, value=k) + values = Tensor(ndim, dtype=dtype) + indices = Tensor(ndim, dtype=ninetoothed.int64) + + dim = ndim - 1 + + output_block_size = _next_power_of_2(k) + arrangement_ = functools.partial( + arrangement, + dim=-1, + block_size=block_size, + output_block_size=output_block_size, + ) + + values.shape = values.shape[:dim] + (k,) + values.shape[dim + 1 :] + indices.shape = indices.shape[:dim] + (k,) + indices.shape[dim + 1 :] + + tensors = (input, dim_size, k_constexpr, values, indices) + + return arrangement_, _application_last, tensors diff --git a/src/ntops/kernels/where.py b/src/ntops/kernels/where.py new file mode 100644 index 0000000..c6ad996 --- /dev/null +++ b/src/ntops/kernels/where.py @@ -0,0 +1,23 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(condition, input, other, output): + output = ntl.where(condition, input, other) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 702877e..3156294 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -14,6 +14,7 @@ from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt +from ntops.torch.index_add import index_add from ntops.torch.isinf import isinf from ntops.torch.isnan import isnan from ntops.torch.layer_norm import layer_norm @@ -24,6 +25,8 @@ from ntops.torch.mul import mul from ntops.torch.ne import ne from ntops.torch.neg import neg +from ntops.torch.nonzero_sum_gt_last2 import nonzero_sum_gt_last2 +from ntops.torch.one_hot import one_hot from ntops.torch.pow import pow from ntops.torch.relu import relu from ntops.torch.rms_norm import rms_norm @@ -36,6 +39,8 @@ from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh +from ntops.torch.topk import topk +from ntops.torch.where import where __all__ = [ "abs", @@ -54,6 +59,7 @@ "ge", "gelu", "gt", + "index_add", "isinf", "isnan", "layer_norm", @@ -64,6 +70,8 @@ "mul", "ne", "neg", + "nonzero_sum_gt_last2", + "one_hot", "pow", "relu", "rms_norm", @@ -76,4 +84,6 @@ "softmax", "sub", "tanh", + "topk", + "where", ] diff --git a/src/ntops/torch/index_add.py b/src/ntops/torch/index_add.py new file mode 100644 index 0000000..fc9a420 --- /dev/null +++ b/src/ntops/torch/index_add.py @@ -0,0 +1,23 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def index_add(input, dim, index, source, *, alpha=1, out=None): + if index.dtype != torch.int64: + raise AssertionError( + "index_add is only applicable to index tensor of type LongTensor." + ) + + if dim != 0: + raise AssertionError("Only dim=0 is supported.") + + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.index_add.premake, input.ndim, dim) + + kernel(input, index, source, alpha, out) + + return out diff --git a/src/ntops/torch/nonzero_sum_gt_last2.py b/src/ntops/torch/nonzero_sum_gt_last2.py new file mode 100644 index 0000000..fd93f5c --- /dev/null +++ b/src/ntops/torch/nonzero_sum_gt_last2.py @@ -0,0 +1,26 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def nonzero_sum_gt_last2(input): + if input.ndim < 2: + raise AssertionError("nonzero_sum_gt_last2 requires input.ndim >= 2") + + if input.dtype is torch.bool: + input_for_kernel = input.to(torch.int8) + else: + input_for_kernel = input + + output_shape = tuple(input_for_kernel.shape[:-2]) + (1, 1) + output = torch.empty( + output_shape, device=input_for_kernel.device, dtype=input_for_kernel.dtype + ) + + kernel = _cached_make(ntops.kernels.nonzero_sum_gt_last2.premake, input.ndim) + + kernel(input_for_kernel, output) + + mask = output.squeeze(-1).squeeze(-1) + return mask.nonzero() diff --git a/src/ntops/torch/one_hot.py b/src/ntops/torch/one_hot.py new file mode 100644 index 0000000..ff22dcf --- /dev/null +++ b/src/ntops/torch/one_hot.py @@ -0,0 +1,42 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def one_hot(input, num_classes=-1): + if input.dtype != torch.int64: + raise AssertionError( + "one_hot is only applicable to index tensor of type LongTensor." + ) + + if input.numel() == 0: + if num_classes is None or num_classes == -1: + raise ValueError( + "Can not infer total number of classes from empty tensor." + ) + num_classes = int(num_classes) + if num_classes <= 0: + raise ValueError("`num_classes` must be positive.") + else: + min_value = int(input.min().item()) + if min_value < 0: + raise ValueError("Class values must be non-negative.") + + if num_classes is None or num_classes == -1: + num_classes = int(input.max().item()) + 1 + else: + num_classes = int(num_classes) + if num_classes <= 0: + raise ValueError("`num_classes` must be positive.") + if int(input.max().item()) >= num_classes: + raise ValueError("Class values must be smaller than num_classes.") + + output_shape = tuple(input.shape) + (num_classes,) + output = torch.empty(output_shape, dtype=torch.int64, device=input.device) + + kernel = _cached_make(ntops.kernels.one_hot.premake, input.ndim, num_classes) + + kernel(input, output) + + return output diff --git a/src/ntops/torch/topk.py b/src/ntops/torch/topk.py new file mode 100644 index 0000000..21155f7 --- /dev/null +++ b/src/ntops/torch/topk.py @@ -0,0 +1,31 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def topk(input, k, dim=-1, largest=True, sorted=True): + if not largest: + raise AssertionError("Only largest=True is supported.") + + if not sorted: + raise AssertionError("Only sorted=True is supported.") + + if dim != -1: + raise AssertionError("Only dim=-1 is supported.") + + assert 0 < k <= input.shape[dim], "`k` must be in (0, input.shape[dim]]." + + output_shape = list(input.shape) + output_shape[dim] = k + + values = torch.empty(output_shape, device=input.device, dtype=input.dtype) + indices = torch.empty(output_shape, device=input.device, dtype=torch.int64) + + dim_size = input.shape[dim] + + kernel = _cached_make(ntops.kernels.topk.premake, input.ndim, dim, k) + + kernel(input, dim_size, k, values, indices) + + return values, indices diff --git a/src/ntops/torch/where.py b/src/ntops/torch/where.py new file mode 100644 index 0000000..1828be2 --- /dev/null +++ b/src/ntops/torch/where.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def where(condition, input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.where.premake, input.ndim) + + kernel(condition, input, other, out) + + return out diff --git a/tests/test_index_add.py b/tests/test_index_add.py new file mode 100644 index 0000000..4f931b2 --- /dev/null +++ b/tests/test_index_add.py @@ -0,0 +1,30 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import gauss, generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_index_add(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + dim = 0 + dim_size = input.shape[dim] + index_size = random.randint(1, dim_size) + + index = torch.randint(0, dim_size, (index_size,), device=device, dtype=torch.int64) + + source_shape = list(shape) + source_shape[dim] = index_size + source = torch.randn(source_shape, dtype=dtype, device=device) + + alpha = gauss() + + ninetoothed_output = ntops.torch.index_add(input, dim, index, source, alpha=alpha) + reference_output = torch.index_add(input, dim, index, source, alpha=alpha) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_nonzero_sum_gt_last2.py b/tests/test_nonzero_sum_gt_last2.py new file mode 100644 index 0000000..78a77b1 --- /dev/null +++ b/tests/test_nonzero_sum_gt_last2.py @@ -0,0 +1,23 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +def _make_mask(shape, dtype, device): + base = torch.randint(0, 2, shape, device=device, dtype=torch.int32) + return base.to(dtype) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape", [(2, 3), (2, 5, 7), (3, 4, 5, 6)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.int32, torch.bool]) +def test_nonzero_sum_gt_last2(shape, dtype): + device = "cuda" + input = _make_mask(shape, dtype, device) + + ninetoothed_output = ntops.torch.nonzero_sum_gt_last2(input) + reference_output = torch.greater(input.sum(dim=(-1, -2)), 0).nonzero() + + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_one_hot.py b/tests/test_one_hot.py new file mode 100644 index 0000000..45f2635 --- /dev/null +++ b/tests/test_one_hot.py @@ -0,0 +1,25 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape", [(4,), (2, 3), (2, 2, 3), (2, 1, 3, 4)]) +@pytest.mark.parametrize("num_classes", [-1, 5, 16]) +def test_one_hot(shape, num_classes): + device = "cuda" + dtype = torch.int64 + + if num_classes == -1: + max_class = 7 + input = torch.randint(0, max_class, size=shape, device=device, dtype=dtype) + ninetoothed_output = ntops.torch.one_hot(input) + reference_output = torch.nn.functional.one_hot(input) + else: + input = torch.randint(0, num_classes, size=shape, device=device, dtype=dtype) + ninetoothed_output = ntops.torch.one_hot(input, num_classes) + reference_output = torch.nn.functional.one_hot(input, num_classes) + + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_topk.py b/tests/test_topk.py new file mode 100644 index 0000000..91a91c4 --- /dev/null +++ b/tests/test_topk.py @@ -0,0 +1,36 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +def _make_unique_along_dim(input, dim, dtype): + dim_size = input.shape[dim] + view_shape = [1] * input.ndim + view_shape[dim] = dim_size + offset = torch.arange(dim_size, device=input.device, dtype=dtype).view(view_shape) + epsilon = 0.01 if dtype == torch.float16 else 1e-4 + return input + offset * epsilon + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_topk(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + dim = -1 + dim_size = input.shape[dim] + k = random.randint(1, min(dim_size, 8)) + + input = _make_unique_along_dim(input, dim, dtype) + + ninetoothed_values, ninetoothed_indices = ntops.torch.topk(input, k, dim=dim) + reference_values, reference_indices = torch.topk( + input, k, dim=dim, largest=True, sorted=True + ) + + assert torch.allclose(ninetoothed_values, reference_values, rtol=rtol, atol=atol) + assert torch.equal(ninetoothed_indices, reference_indices) diff --git a/tests/test_where.py b/tests/test_where.py new file mode 100644 index 0000000..f4075c3 --- /dev/null +++ b/tests/test_where.py @@ -0,0 +1,19 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_where(shape, dtype, device, rtol, atol): + condition = torch.rand(shape, device=device) > 0.5 + input = torch.randn(shape, dtype=dtype, device=device) + other = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.where(condition, input, other) + reference_output = torch.where(condition, input, other) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)