|
| 1 | + |
| 2 | +# `EmbeddingBackward` |
| 3 | + |
| 4 | +`EmbeddingBackward`,即[**嵌入**算子](/infiniop/ops/embedding/README.md)的反向算子。用于训练大模型的词嵌入和加性位置嵌入。 |
| 5 | + |
| 6 | +`EmbeddingBackward` 算子支持 1 个或 2 个相同的步骤,根据“号码”从将输出的梯度叠加到嵌入表的梯度,其公式表述为: |
| 7 | + |
| 8 | +$$ \begin{equation} d_{table1} = \alpha_1 \cdot dy[i_1] \end{equation} $$ |
| 9 | + |
| 10 | +$$ \begin{equation} d_{table2} = \alpha_2 \cdot dy[i_2] \end{equation} $$ |
| 11 | + |
| 12 | +- 通常 $α$ 为 1; |
| 13 | +- $table2$ 可以不使用,则公式 $(2)$ 不存在; |
| 14 | + |
| 15 | +## 接口 |
| 16 | + |
| 17 | +### 计算 |
| 18 | + |
| 19 | +```c |
| 20 | +infiniStatus_t infiniopEmbeddingBackward( |
| 21 | + infiniopEmbeddingBackwardDescriptor_t desc, |
| 22 | + void *dtable1, |
| 23 | + void *dtable2, |
| 24 | + const void *dy, |
| 25 | + const void *i1, |
| 26 | + const void *i2, |
| 27 | + void *stream |
| 28 | +); |
| 29 | +``` |
| 30 | + |
| 31 | +<div style="background-color: lightblue; padding: 1px;"> 参数: </div> |
| 32 | + |
| 33 | +- `desc`: |
| 34 | + 已使用 `infiniopEmbeddingBackwardDescriptor_t()` 初始化的算子描述符; |
| 35 | +- `dtable1`: |
| 36 | + 第 1 个嵌入表的梯度; |
| 37 | +- `dtable2`: |
| 38 | + 第 2 个嵌入表的梯度,不使用则为空; |
| 39 | +- `dy`: |
| 40 | + 输出结果的梯度; |
| 41 | +- `i1`: |
| 42 | + 第 1 个嵌入序号; |
| 43 | +- `i2`: |
| 44 | + 第 2 个嵌入序号,不使用则为空; |
| 45 | +- `stream`: |
| 46 | + 计算流/队列; |
| 47 | + |
| 48 | +<div style="background-color: lightblue; padding: 1px;"> 返回值:</div> |
| 49 | + |
| 50 | +- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_DEVICE`], [`INFINI_STATUS_EXECUTION_FAILED`]. |
| 51 | + |
| 52 | +### 创建算子描述 |
| 53 | + |
| 54 | +```c |
| 55 | +infiniStatus_t infiniopCreateEmbeddingBackwardDescriptor( |
| 56 | + infiniopHandle_t handle, |
| 57 | + infiniopEmbeddingBackwardDescriptor_t *desc_ptr, |
| 58 | + infiniopTensorDescriptor_t dtable1_desc, |
| 59 | + infiniopTensorDescriptor_t dtable2_desc, |
| 60 | + infiniopTensorDescriptor_t dy_desc, |
| 61 | + infiniopTensorDescriptor_t i1_desc, |
| 62 | + infiniopTensorDescriptor_t i2_desc, |
| 63 | + float alpha1, |
| 64 | + float alpha2, |
| 65 | + char dtable1_acc, |
| 66 | + char dtable2_acc |
| 67 | +); |
| 68 | +``` |
| 69 | + |
| 70 | +<div style="background-color: lightblue; padding: 1px;"> 参数:</div> |
| 71 | + |
| 72 | +- `handle`: |
| 73 | + `infiniopHandle_t` 类型的硬件控柄。详情请看:[`InfiniopHandle_t`] |
| 74 | +- `desc_ptr`: |
| 75 | + `infiniopCreateEmbeddingBackwardDescriptor` 指针,指向将被初始化的算子描述符地址; |
| 76 | +- `dtable1_desc` - $\{ dT | (N1, D) | (..., 1) \}$: |
| 77 | + 算子输入 `table1` 的张量描述; |
| 78 | +- `dtable2_desc` - $\{ dT | (N2, D) | (..., 1) \}$: |
| 79 | + 算子输入 `table2` 的张量描述; |
| 80 | +- `dy_desc` - $\{ dT | (N, D) | (..., 1) \}$: |
| 81 | + 算子输出 `y` 的张量描述; |
| 82 | +- `i1_desc` - $\{ dI | (N) | (1) \}$: |
| 83 | + 算子输入 `i1` 的张量描述; |
| 84 | +- `i2_desc` - $\{ dI | (N) | (1) \}$: |
| 85 | + 算子输入 `i2` 的张量描述,为空表示不使用 $ table2 $, `alpha2` 必须同时为 0; |
| 86 | +- `alpha1` - float: |
| 87 | + 第 1 项嵌入的缩放因子; |
| 88 | +- `alpha2` - float: |
| 89 | + 第 2 项嵌入的缩放因子,取 0 表示不使用 $ table2 $,`i2_desc` 必须同时为空; |
| 90 | +- `dtable1_acc` - char: |
| 91 | + 第 1 项嵌入是否叠加梯度,0 表示不叠加; |
| 92 | +- `dtable2_acc` - float: |
| 93 | + 第 2 项嵌入是否叠加梯度,0 表示不叠加; |
| 94 | + |
| 95 | +<div style="background-color: lightblue; padding: 1px;"> 参数限制:</div> |
| 96 | + |
| 97 | +- $dT$: 任意代数类型; |
| 98 | +- $dT_i$: 任意整型; |
| 99 | + |
| 100 | +<div style="background-color: lightblue; padding: 1px;"> 返回值:</div> |
| 101 | + |
| 102 | +- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_TENSOR_SHAPE`], [`INFINI_STATUS_BAD_TENSOR_DTYPE`], [`INFINI_STATUS_BAD_TENSOR_STRIDES`], [`INFINI_STATUS_BAD_DEVICE`]. |
| 103 | + |
| 104 | +### 销毁算子描述符 |
| 105 | + |
| 106 | +```c |
| 107 | +infiniStatus_t infiniopDestroyEmbeddingBackwardDescriptor( |
| 108 | + infiniopEmbeddingBackwardDescriptor_t desc |
| 109 | +); |
| 110 | +``` |
| 111 | + |
| 112 | +<div style="background-color: lightblue; padding: 1px;"> 参数: </div> |
| 113 | + |
| 114 | +- `desc`: |
| 115 | + 输入。待销毁的算子描述符; |
| 116 | + |
| 117 | +<div style="background-color: lightblue; padding: 1px;"> 返回值: </div> |
| 118 | + |
| 119 | +- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_DEVICE`]. |
| 120 | + |
| 121 | +<!-- 链接 --> |
| 122 | +[`InfiniopHandle_t`]: /infiniop/handle/README.md |
| 123 | + |
| 124 | +[`INFINI_STATUS_SUCCESS`]:/common/status/README.md#INFINI_STATUS_SUCCESS |
| 125 | +[`INFINI_STATUS_BAD_PARAM`]:/common/status/README.md#INFINI_STATUS_BAD_PARAM |
| 126 | +[`INFINI_STATUS_BAD_DEVICE`]:/common/status/README.md#INFINI_STATUS_BAD_DEVICE |
| 127 | +[`INFINI_STATUS_EXECUTION_FAILED`]:/common/status/README.md#INFINI_STATUS_EXECUTION_FAILED |
| 128 | +[`INFINI_STATUS_BAD_TENSOR_SHAPE`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_SHAPE |
| 129 | +[`INFINI_STATUS_BAD_TENSOR_DTYPE`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_DTYPE |
| 130 | +[`INFINI_STATUS_BAD_TENSOR_STRIDES`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_STRIDES |
0 commit comments