Skip to content

Commit 271ca37

Browse files
Merge pull request #48 from GushanFall/T1-3-1
[T1-3-1] GushanFall
2 parents 519eceb + 365610d commit 271ca37

File tree

2 files changed

+366
-0
lines changed

2 files changed

+366
-0
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# `Flash Attention`
2+
3+
`Flash Attention`,是一种高效的**自注意力机制实现**,在加速注意力计算的同时也减少了内存占用。其核心原理是将输入分块,在每个块上分别进行注意力计算,从而减少对 HBM 的读写次数以提高计算效率。
4+
5+
标准 Attention 的计算为:
6+
1. 计算 $QK^T$,得到初步的注意力分数 attention_score;
7+
2. 添加位置编码以区别不同位置的元素,并乘以缩放系数;
8+
3. 根据注意力掩码屏蔽对应位置的元素,得到 masked_attention_score;
9+
4. 最后经过 softmax 计算,并与 $V$ 相乘,得到最终的注意力分数 attention_out。
10+
11+
从中可以发现,若 SRAM 无法存储完整的计算数据,则计算过程中需要频繁访问 HBM,I/O 需求为 $O(Nd+N^2)$。此前,由于 softmax 的计算需要遍历整个序列以得到整体最大值,所以优化的方式往往是通过近似的方式来降低数据占用的空间。
12+
13+
而 FlashAttention 给出了 softmax 的分块计算方式,以此代替原公式中的 softmax 计算,过程如下:
14+
1. 提前对输入的 QKV 进行分块(分别按照 QKV 的 sequence length 方向切分);
15+
2. 增量计算分块 softmax ,并维护两个全局变量
16+
- $m$:已处理块的最大值
17+
- $\ell$:已处理块的指数和
18+
19+
以计算 attention_out 中的第 j 块为例,需遍历 $Q$ 切分后的每一块
20+
1. $Q_0$ 参与计算时,正常计算,并将计算结果直接保存在 attention_out[j] 中,并更新 $m$ 与 $\ell$;
21+
2. 此后, $Q_i$ 参与计算时,先统计当前块 masked_attention_score 的最大值 $m_i$,并计算其指数和 $\ell_i$;
22+
```
23+
P = e^{masked_attention_score - m_i}
24+
l_i = rowsum(P)
25+
```
26+
3. 更新 $m_{new}$ 为 $\max(m_i,m)$ ,并根据新的最大值计算 $\ell_{new}$,再由此计算新的 attention_out[j];
27+
```
28+
l_new = e^{m-m_new}*l+e^{m_i-m_new}
29+
attention_out = (attention_out * l * e^{m-m_new} + e^{m_i-m_new} * P * V_j) / l_new
30+
```
31+
4. 分别用 $m_{new}$ 与 $\ell_{max}$ 更新 $m$ 与 $\ell$,开启新一块的计算;
32+
5. 遍历完 $Q$ 之后,即可得到完整的 attention_out[j]
33+
34+
FlashAttention-2 则在 FlashAttention 的基础上进行了两点改进:
35+
36+
1. 在计算局部 Attention 时,先不考虑分母的指数和,而是 $\mathbf{O}_{i}^{(j)}=\mathrm{diag}(e^{m_{i}^{(j-1)}-m_{i}^{(j)}})\mathbf{O}_{i}^{(j-1)}+\tilde{\mathbf{P}}_{i}^{(j)}\mathbf{V}_{j}$
37+
2. 在最后一步再带入计算,得到正确的结果 $\mathbf{O}_{i}=\mathrm{diag}(\ell_{i}^{(T_{c})})^{-1}\mathbf{O}_{i}^{(T_{c})}$
38+
39+
同时,FlashAttention-2 将 $Q$ 移到了外循环,而 $K$ 移到了内循环。由于改进了算法,使得 warps 之间不需要相互通信,所以外循环可以放在不同的 thread block 上。
40+
41+
42+
## 接口
43+
44+
### 计算
45+
46+
```c
47+
infiniStatus_t infiniopFlashAttention(
48+
infiniopFlashAttentionDescriptor_t desc,
49+
void *workspace,
50+
size_t workspace_size,
51+
void *out,
52+
void *l,
53+
const void *q,
54+
const void *k,
55+
const void *v,
56+
const void *mask,
57+
void *stream
58+
);
59+
```
60+
61+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
62+
63+
- `desc`:
64+
已使用 `infiniopCreateFlashAttentionDescriptor()` 初始化的算子描述符;
65+
- `workspace`:
66+
指向算子计算所需的额外工作空间;
67+
- `workspace_size`:
68+
`workspace` 的大小,单位:字节;
69+
- `out`:
70+
注意力计算结果地址。张量限制见[创建算子描述](#创建算子描述)部分。
71+
- `l`:
72+
logsumexp,用于反向传播计算的暂存结果。张量限制见[创建算子描述](#创建算子描述)部分。
73+
- `q`:
74+
查询(Query)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
75+
- `k`:
76+
键(Key)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
77+
- `v`:
78+
值(Value)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
79+
- `mask`:
80+
注意力掩码的数据指针,可选参数。取值为 `0` 时表示保留对应位置的元素(参与计算);取值为 `-inf` (负无穷)时表示屏蔽对应位置的元素(即跳过,不参与计算)。张量限制见[创建算子描述](#创建算子描述)部分。
81+
- `stream`:
82+
计算流/队列。
83+
84+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
85+
86+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_NULL_POINTER`], [`INFINI_STATUS_INSUFFICIENT_WORKSPACE`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`], [`INFINI_STATUS_INTERNAL_ERROR`].
87+
88+
### 创建算子描述
89+
90+
```c
91+
infiniStatus_t infiniopCreateFlashAttentionDescriptor(
92+
infiniopHandle_t handle,
93+
infiniopFlashAttentionDescriptor_t *desc_ptr,
94+
infiniopTensorDescriptor_t out_desc,
95+
infiniopTensorDescriptor_t l_desc,
96+
infiniopTensorDescriptor_t q_desc,
97+
infiniopTensorDescriptor_t k_desc,
98+
infiniopTensorDescriptor_t v_desc,
99+
infiniopTensorDescriptor_t mask_desc,
100+
infiniopAttentionMaskType_t mask_type
101+
);
102+
```
103+
104+
<div style="background-color: lightblue; padding: 1px;"> 参数:</div>
105+
106+
- `handle`:
107+
`infiniopHandle_t`类型的硬件控柄。详见 [`InfiniopHandle_t`]
108+
- `desc_ptr`:
109+
`infiniopFlashAttentionDescriptor_t` 指针,指向将被初始化的算子描述符地址。
110+
- `out_desc` - { dT | ((batch_size,) seq_len_q, num_heads_q, head_dim) | ($\ldots, 1$)}:
111+
算子计算参数 `out` 的张量描述,四维或者三维,最后一维连续。
112+
- `l_desc` - { dT | ((batch_szie,) seq_len_q, num_heads_q) | ($\ldots, 1$)};
113+
算子计算参数 `l` 的张量描述,三维或者二维。
114+
- `q_desc` - { dT | ((batch_size,) seq_len_q, num_heads_q, head_dim) | ($\ldots, 1$)}:
115+
算子计算参数 `q` 的张量描述,形状与 `out_desc` 一致,最后一维连续。
116+
- `k_desc` - { dT | ((batch_size,) seq_len_kv, num_heads_kv, head_dim) | ($\ldots, 1$)}:
117+
算子计算参数 `k` 的张量描述,最后一维连续。
118+
- `v_desc` - { dT | ((batch_size,) seq_len_kv, num_heads_kv, head_dim) | ($\ldots, 1$)}:
119+
算子计算参数 `v` 的张量描述,形状与 `k_desc` 一致,最后一维连续。
120+
- `mask_desc` - { dM | (seq_len_q, seq_len_kv) | (~)}:
121+
算子计算参数 `mask` 的张量描述,当 `mask_type=INFINIOP_ATTENTION_MASK_TYPE_FULL` 时,`mask` 不可为空,其余情况 `mask` 可为`nullptr`
122+
- `mask_type` - `infiniopAttentionMaskType_t`:
123+
注意力类型参数,有三种类型可选,详细见参数限制。
124+
125+
参数限制:
126+
127+
- `dT`: `Float16`, `Float32``BFloat16`
128+
- `dM`: `Flaot32`
129+
- `seq_len_q``seq_len_kv` 可以不同。
130+
- `num_heads_q``num_heads_kv` 可以不同,但需满足前者是后者的整数倍(非0整数)。
131+
- 当 $N_q/N_{kv}=1$ 时,即为 MQA (multi-query attention)
132+
- 当 $N_q/N_{kv}>1$ 时,即为 GQA (grouped-query attention)
133+
- `mask_type` 的三种类型:
134+
- `INFINIOP_ATTENTION_MASK_TYPE_NONE=0`: 不使用注意力掩码,忽略 `mask` 取值;
135+
- `INFINIOP_ATTENTION_MASK_TYPE_FULL=1`: 使用完整 mask 矩阵,此时 `mask` 不可为空;
136+
- `INFINIOP_ATTENTION_MASK_TYPE_CAUSAL=2`: 使用标准因果掩码,对应以左上顶点划分的下三角场景,忽略 `mask` 取值;
137+
138+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
139+
140+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_TENSOR_SHAPE`], [`INFINI_STATUS_BAD_TENSOR_DTYPE`], [`INFINI_STATUS_BAD_TENSOR_STRIDES`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
141+
142+
### 计算额外工作空间
143+
144+
```c
145+
infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
146+
infiniopFlashAttentionDescriptor_t desc,
147+
size_t *size
148+
);
149+
```
150+
151+
<div style="background-color: lightblue; padding: 1px;"> 参数:</div>
152+
153+
- `desc`:
154+
已使用 `infiniopCreateFlashAttentionDescriptor()` 初始化的算子描述符;
155+
- `size`:
156+
额外空间大小的计算结果的写入地址;
157+
158+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
159+
160+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
161+
162+
### 销毁算子描述符
163+
164+
```c
165+
infiniStatus_t infiniopDestoryFlashAttentionDescriptor(
166+
infiniopFlashAttentionDescriptor_t desc
167+
);
168+
```
169+
170+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
171+
172+
- `desc`:
173+
输入。 待销毁的算子描述符;
174+
175+
<div style="background-color: lightblue; padding: 1px;"> 返回值: </div>
176+
177+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
178+
179+
<!-- 链接 -->
180+
[`InfiniopHandle_t`]: /infiniop/handle/README.md
181+
182+
[`INFINI_STATUS_SUCCESS`]: /common/status/README.md#INFINI_STATUS_SUCCESS
183+
[`INFINI_STATUS_BAD_PARAM`]: /common/status/README.md#INFINI_STATUS_BAD_PARAM
184+
[`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`]: /common/status/README.md#INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
185+
[`INFINI_STATUS_BAD_TENSOR_SHAPE`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_SHAPE
186+
[`INFINI_STATUS_BAD_TENSOR_DTYPE`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_DTYPE
187+
[`INFINI_STATUS_BAD_TENSOR_STRIDES`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_STRIDES
188+
[`INFINI_STATUS_NULL_POINTER`]:/common/status/README.md#INFINI_STATUS_NULL_POINTER
189+
[`INFINI_STATUS_INSUFFICIENT_WORKSPACE`]:/common/status/README.md#INFINI_STATUS_INSUFFICIENT_WORKSPACE
190+
[`INFINI_STATUS_INTERNAL_ERROR`]:/common/status/README.md#INFINI_STATUS_INTERNAL_ERROR
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# `FlashAttentionBackward`
2+
3+
`FlashAttentionBackward` 是算子 `FlashAttention` 的反向传播。
4+
5+
与正向传播相同,先进行分块。 $K$ 和 $V$ 在外循环中逐块加载,而 $Q$ 在内循环中逐块加载。
6+
7+
1. 外循环中,每次循环需要先初始化当前块的 $\mathbf{d}\mathbf{K}_j$ 和 $\mathbf{d}\mathbf{V}_j$ 为 0;
8+
2. 内循环中,按以下顺序计算
9+
1. $\mathbf{S}_i^{(j)}=\mathbf{Q}_i\mathbf{K}_j^T\in\mathbb{R}^{B_r\times B_c}$
10+
2. $\mathbf{P}_i^{(j)}=\exp(\mathbf{S} _{ij}-L_i)\in\mathbb{R}^{B_r\times B_c}$
11+
3. $\mathbf{d}\mathbf{V}_j\leftarrow\mathbf{d}\mathbf{V}_j+(\mathbf{P}_i^{(j)})^\top\mathbf{d}\mathbf{O}_i\in\mathbb{R}^{B_c\times d}$
12+
4. $\mathbf{dP}_i^{(j)}=\mathbf{dO}_i\mathbf{V}_j^\top\in\mathbb{R}^{B_r\times B_c}$
13+
5. $\mathbf{dQ}_i\leftarrow\mathbf{dQ}_i+\mathbf{dS}_i^{(j)}\mathbf{K}_j$
14+
6. $\mathbf{dK}_j\leftarrow\mathbf{dK}_j+\mathbf{dS}_i^{(j)\top}\mathbf{Q}_i\in\mathbb{R}^{B_c\times d}$
15+
16+
## 接口
17+
18+
### 计算
19+
20+
```c
21+
infiniStatus_t infiniopFlashAttentionBackward(
22+
infiniopFlashAttentionBackwardDescriptor_t desc,
23+
void *workspace,
24+
size_t workspace_size,
25+
void *grad_q,
26+
void *grad_k,
27+
void *grad_v,
28+
const void *q,
29+
const void *k,
30+
const void *v,
31+
const void *grad_out,
32+
const void *mask,
33+
void *stream
34+
);
35+
```
36+
37+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
38+
39+
- `desc`:
40+
已使用 `infiniopCreateFlashAttentionBackwardDescriptor()` 初始化的算子描述符;
41+
- `workspace`:
42+
指向算子计算所需的额外工作空间;
43+
- `workspace_size`:
44+
`workspace` 的大小,单位:字节;
45+
- `grad_q`:
46+
查询(Query)梯度计算结果地址。张量限制见[创建算子描述](#创建算子描述)部分。
47+
- `grad_k`:
48+
键(Key)梯度计算结果地址。张量限制见[创建算子描述](#创建算子描述)部分。
49+
- `grad_v`:
50+
值(Value)梯度计算结果地址。张量限制见[创建算子描述](#创建算子描述)部分。
51+
- `q`:
52+
查询(Query)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
53+
- `k`:
54+
键(Key)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
55+
- `v`:
56+
值(Value)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
57+
- `grad_out`
58+
注意力梯度张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
59+
- `mask`:
60+
注意力掩码的数据指针,可选参数。取值为 `0` 时表示保留对应位置的元素(参与计算);取值为 `-inf` (负无穷)时表示屏蔽对应位置的元素(即跳过,不参与计算)。张量限制见[创建算子描述](#创建算子描述)部分。
61+
- `stream`:
62+
计算流/队列。
63+
64+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
65+
66+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_NULL_POINTER`], [`INFINI_STATUS_INSUFFICIENT_WORKSPACE`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`], [`INFINI_STATUS_INTERNAL_ERROR`].
67+
68+
### 创建算子描述
69+
70+
```c
71+
infiniStatus_t infiniopCreateFlashAttentionBackwardDescriptor(
72+
infiniopHandle_t handle,
73+
infiniopFlashAttentionBackwardDescriptor_t *desc_ptr,
74+
infiniopTensorDescriptor_t grad_q_desc,
75+
infiniopTensorDescriptor_t grad_k_desc,
76+
infiniopTensorDescriptor_t grad_v_desc,
77+
infiniopTensorDescriptor_t q_desc,
78+
infiniopTensorDescriptor_t k_desc,
79+
infiniopTensorDescriptor_t v_desc,
80+
infiniopTensorDescriptor_t grad_out_desc,
81+
infiniopTensorDescriptor_t mask_desc,
82+
infiniopAttentionMaskType_t mask_type
83+
);
84+
```
85+
86+
<div style="background-color: lightblue; padding: 1px;"> 参数:</div>
87+
88+
- `handle`:
89+
`infiniopHandle_t`类型的硬件控柄。详见 [`InfiniopHandle_t`]
90+
- `desc_ptr`:
91+
`infiniopFlashAttentionBackwardDescriptor_t` 指针,指向将被初始化的算子描述符地址。
92+
- `grad_q_desc` - { dT | ((batch_size,) seq_len_q, num_heads_q, head_dim) | ($\ldots, 1$)}:
93+
算子计算参数 `grad_q` 的张量描述,四维或者三维,最后一维连续。
94+
- `grad_k_desc` - { dT | ((batch_size,) seq_len_kv, num_heads_kv, head_dim) | ($\ldots, 1$)}:
95+
算子计算参数 `grad_k` 的张量描述,四维或者三维,最后一维连续。
96+
- `grad_v_desc` - { dT | ((batch_size,) seq_len_kv, num_heads_kv, head_dim) | ($\ldots, 1$)}:
97+
算子计算参数 `grad_v` 的张量描述,四维或者三维,最后一维连续。
98+
- `q_desc` - { dT | ((batch_size,) seq_len_q, num_heads_q, head_dim) | ($\ldots, 1$)}:
99+
算子计算参数 `q` 的张量描述,形状与 `grad_q_desc` 一致,最后一维连续。
100+
- `k_desc` - { dT | ((batch_size,) seq_len_kv, num_heads_kv, head_dim) | ($\ldots, 1$)}:
101+
算子计算参数 `k` 的张量描述,形状与 `grad_k_desc` 一致,最后一维连续。
102+
- `v_desc` - { dT | ((batch_size,) seq_len_kv, num_heads_kv, head_dim) | ($\ldots, 1$)}:
103+
算子计算参数 `v` 的张量描述,形状与 `grad_v_desc` 一致,最后一维连续。
104+
- `grad_out` - { dT | ((batch_size,) seq_len_q, num_heads_q, head_dim) | ($\ldots, 1$)}:
105+
算子计算参数 `grad_out` 的张量描述,形状与 `grad_q_desc` 一致,最后一维连续。
106+
- `mask_desc` - { dM | (seq_len_q, seq_len_kv) | (~)}:
107+
算子计算参数 `mask` 的张量描述,当 `mask_type=INFINIOP_ATTENTION_MASK_TYPE_FULL` 时,`mask` 不可为空,其余情况 `mask` 可为`nullptr`
108+
- `mask_type` - `infiniopAttentionMaskType_t`:
109+
注意力类型参数,有三种类型可选,详细见参数限制。
110+
111+
参数限制:
112+
113+
- `dT`: `Float16`, `Float32``BFloat16`
114+
- `dM`: `Flaot32`
115+
- `seq_len_q``seq_len_kv` 可以不同。
116+
- `num_heads_q``num_heads_kv` 可以不同,但需满足前者是后者的整数倍(非0整数)。
117+
- 当 $N_q/N_{kv}=1$ 时,即为 MQA (multi-query attention)
118+
- 当 $N_q/N_{kv}>1$ 时,即为 GQA (grouped-query attention)
119+
- `mask_type` 的三种类型:
120+
- `INFINIOP_ATTENTION_MASK_TYPE_NONE=0`: 不使用注意力掩码,忽略 `mask` 取值;
121+
- `INFINIOP_ATTENTION_MASK_TYPE_FULL=1`: 使用完整 mask 矩阵,此时 `mask` 不可为空;
122+
- `INFINIOP_ATTENTION_MASK_TYPE_CAUSAL=2`: 使用标准因果掩码,对应以左上顶点划分的下三角场景,忽略 `mask` 取值;
123+
124+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
125+
126+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_TENSOR_SHAPE`], [`INFINI_STATUS_BAD_TENSOR_DTYPE`], [`INFINI_STATUS_BAD_TENSOR_STRIDES`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
127+
128+
### 计算额外工作空间
129+
130+
```c
131+
infiniStatus_t infiniopGetFlashAttentionBackwardWorkspaceSize(
132+
infiniopFlashAttentionBackwardDescriptor_t desc,
133+
size_t *size
134+
);
135+
```
136+
137+
<div style="background-color: lightblue; padding: 1px;"> 参数:</div>
138+
139+
- `desc`:
140+
已使用 `infiniopCreateFlashAttentionBackwardDescriptor()` 初始化的算子描述符;
141+
- `size`:
142+
额外空间大小的计算结果的写入地址;
143+
144+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
145+
146+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
147+
148+
### 销毁算子描述符
149+
150+
```c
151+
infiniStatus_t infiniopDestoryFlashAttentionBackwardDescriptor(
152+
infiniopFlashAttentionBackwardDescriptor_t desc
153+
);
154+
```
155+
156+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
157+
158+
- `desc`:
159+
输入。 待销毁的算子描述符;
160+
161+
<div style="background-color: lightblue; padding: 1px;"> 返回值: </div>
162+
163+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
164+
165+
<!-- 链接 -->
166+
[`InfiniopHandle_t`]: /infiniop/handle/README.md
167+
168+
[`INFINI_STATUS_SUCCESS`]: /common/status/README.md#INFINI_STATUS_SUCCESS
169+
[`INFINI_STATUS_BAD_PARAM`]: /common/status/README.md#INFINI_STATUS_BAD_PARAM
170+
[`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`]: /common/status/README.md#INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
171+
[`INFINI_STATUS_BAD_TENSOR_SHAPE`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_SHAPE
172+
[`INFINI_STATUS_BAD_TENSOR_DTYPE`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_DTYPE
173+
[`INFINI_STATUS_BAD_TENSOR_STRIDES`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_STRIDES
174+
[`INFINI_STATUS_NULL_POINTER`]:/common/status/README.md#INFINI_STATUS_NULL_POINTER
175+
[`INFINI_STATUS_INSUFFICIENT_WORKSPACE`]:/common/status/README.md#INFINI_STATUS_INSUFFICIENT_WORKSPACE
176+
[`INFINI_STATUS_INTERNAL_ERROR`]:/common/status/README.md#INFINI_STATUS_INTERNAL_ERROR

0 commit comments

Comments
 (0)