
Commit de1d2df

Merge pull request #467 from Xilinx/planzase.rotary-embedding_and_attention_ops
Add support for RotaryEmbedding and Attention ONNX ops
2 parents 23050e2 + 781c3e3 commit de1d2df

File tree

10 files changed (+982 / -0 lines)

docs/Dialects/onnx.md

Lines changed: 220 additions & 0 deletions
@@ -511,6 +511,91 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values

### `onnx.Attention` (ONNXAttentionOp)

_ONNX Attention operation_

Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.

This operator covers the self and cross variants of the attention operation based on the sequence lengths of K, Q and V.

For self attention, `kv_sequence_length` equals `q_sequence_length`.

For cross attention, query and key might have different lengths.

This operator also covers the following 3 variants based on the number of heads (an illustrative sketch of the group-query case follows the list):
1) Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`.
2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
3) Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads = 1`.
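
As an illustration of the head-count relationships above (not taken from this PR; the function name and tensor layout below are assumptions), a common way to realize GQA and MQA is to repeat each key/value head across its group of query heads:

```python
import numpy as np

def expand_kv_heads(kv, q_num_heads, kv_num_heads):
    # kv: [batch, kv_num_heads, kv_sequence_length, head_size]
    assert q_num_heads % kv_num_heads == 0
    group = q_num_heads // kv_num_heads  # MHA: 1, GQA: > 1, MQA: q_num_heads (kv_num_heads == 1)
    # Repeat each K/V head so every query head in a group attends to the same K/V head
    return np.repeat(kv, group, axis=1)  # -> [batch, q_num_heads, kv_sequence_length, head_size]
```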

The attention bias to be added is computed from the `attn_mask` input and the `is_causal` attribute, only one of which can be provided.
1) If `is_causal` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment.
2) `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention, or a float mask of the same type as query, key and value that is added to the attention score.

Both past and present state key/values are optional. They must be used together; using only one of them is not allowed.
The following pattern is applied to the Q, K and V inputs after appropriate reshaping of the K and V inputs based on the sequence lengths and num heads provided:

```
The following pattern is applied by this operator:
      Q             K          V
      |             |          |
Q*sqrt(scale)  K*sqrt(scale)   |
      |             |          |
      |         Transpose      |
      |             |          |
      ----MatMul----           |
             |                 |
 at_mask----Add                |
             |                 |
  softcap (if provided)        |
             |                 |
          Softmax              |
             |                 |
      -------MatMul-------------
                  |
                  Y
```
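
A minimal NumPy sketch of the pattern above, under simplifying assumptions (4-D `[batch, num_heads, sequence_length, head_size]` inputs, `q_num_heads == kv_num_heads`, no past/present state, `softcap` and `qk_matmul_output` omitted); it illustrates the dataflow only, not the lowering added by this PR:

```python
import numpy as np

def attention_pattern(Q, K, V, attn_mask=None, is_causal=0, scale=None):
    # Q: [batch, num_heads, q_seq, head_size]; K, V: [batch, num_heads, kv_seq, head_size]
    if scale is None:
        scale = 1.0 / np.sqrt(Q.shape[-1])  # assumed default scaling (1/sqrt(head_size))
    # Q*sqrt(scale) and K*sqrt(scale), transpose of K, then MatMul -> [batch, num_heads, q_seq, kv_seq]
    scores = (Q * np.sqrt(scale)) @ np.swapaxes(K * np.sqrt(scale), -1, -2)
    # Attention bias: either is_causal or attn_mask, not both
    if is_causal:
        keep = np.tril(np.ones(scores.shape[-2:], dtype=bool))  # lower-triangular (square) causal mask
        scores = np.where(keep, scores, -np.inf)
    elif attn_mask is not None:
        if attn_mask.dtype == bool:
            scores = np.where(attn_mask, scores, -np.inf)  # True means "take part in attention"
        else:
            scores = scores + attn_mask  # float mask is added to the attention score
    # Softmax over the key dimension, then weight V
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V  # Y: [batch, num_heads, q_seq, head_size]
```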

Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<23>`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>is_causal</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>kv_num_heads</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>q_num_heads</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>qk_matmul_output_mode</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>scale</code></td><td>::mlir::FloatAttr</td><td>32-bit float attribute</td></tr>
<tr><td><code>softcap</code></td><td>::mlir::FloatAttr</td><td>32-bit float attribute</td></tr>
<tr><td><code>softmax_precision</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `Q` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `K` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `V` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `attn_mask` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 1-bit signless integer values or none type
| `past_key` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type
| `past_value` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type

#### Results:

| Result | Description |
| :----: | ----------- |
| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `present_key` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type
| `present_value` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type
| `qk_matmul_output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type

### `onnx.AveragePool` (ONNXAveragePoolOp)

_ONNX AveragePool operation_

@@ -8575,6 +8660,141 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values

### `onnx.RotaryEmbedding` (ONNXRotaryEmbeddingOp)

_ONNX RotaryEmbedding operation_

RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864.
The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances
between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token's absolute position (position_ids).

The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles.
For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the
embedding vector either into two halves or interleaving every alternate token and applying the rotation matrix to each half of the embedding vector.
The rotation matrix is parameterized by the token's position in the sequence. The rotated halves of the embedding vector are concatenated
to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism.
The rotation ensures that the model captures both absolute and relative positional information.
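
Written out for one feature pair, with symbols introduced here only for illustration: if $x_1, x_2$ are the paired components of a token's embedding and the caches supply $\cos(p\theta)$ and $\sin(p\theta)$ for that token's position $p$ and per-pair frequency $\theta$, the rotation applied is

$$
\begin{pmatrix} x_1' \\ x_2' \end{pmatrix}
=
\begin{pmatrix} \cos(p\theta) & -\sin(p\theta) \\ \sin(p\theta) & \cos(p\theta) \end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix},
$$

which corresponds to the `real`/`imag` computation in the reference algorithm below.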

Rotary embeddings are defined using the following algorithm:

```python
import numpy as np


def rotary_embedding(
    input: np.ndarray,
    cos_cache: np.ndarray,
    sin_cache: np.ndarray,
    position_ids: np.ndarray | None = None,
    interleaved=None,
    rotary_embedding_dim=None,
    num_heads=None,
) -> np.ndarray:
    original_input_shape = input.shape
    # First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size]
    if len(input.shape) == 4:
        input = np.transpose(input, (0, 2, 1, 3))
    batch_size = input.shape[0]
    sequence_length = input.shape[1]
    if len(input.shape) == 3:
        hidden_size = input.shape[2]
        assert num_heads != 0
        head_size = int(hidden_size / num_heads)
        new_shape = [batch_size, sequence_length, num_heads, head_size]
        input = np.reshape(input, new_shape)
    assert len(input.shape) == 4
    head_size = input.shape[3]

    # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
    if rotary_embedding_dim is None or rotary_embedding_dim == 0:
        # If rotary_embedding_dim not provided, perform full rotation by using head_size
        rotary_embedding_dim = head_size
    x_rotate = input[:, :, :, :rotary_embedding_dim]
    x_not_rotate = input[:, :, :, rotary_embedding_dim:]
    rotary_embedding_dim_half = int(rotary_embedding_dim / 2)

    # Retrieve sin and cos caches using position ids
    if position_ids is not None:
        cos_cache = cos_cache[
            position_ids
        ]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
        sin_cache = sin_cache[
            position_ids
        ]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]

    # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
    if cos_cache.shape[-1] != rotary_embedding_dim_half:
        raise ValueError(
            f"Last dimension of cos cache ({cos_cache.shape[-1]}) does not match rotary_embedding_dim/2 ({rotary_embedding_dim_half})."
        )
    if sin_cache.shape[-1] != rotary_embedding_dim_half:
        raise ValueError(
            f"Last dimension of sin cache ({sin_cache.shape[-1]}) does not match rotary_embedding_dim/2 ({rotary_embedding_dim_half})."
        )

    cos_cache = np.expand_dims(
        cos_cache, axis=2
    )  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
    sin_cache = np.expand_dims(
        sin_cache, axis=2
    )  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]

    # Either divide the input in halves or interleave (based on interleaved attribute)
    if interleaved:
        x1 = x_rotate[:, :, :, 0::2]
        x2 = x_rotate[:, :, :, 1::2]
    else:
        x1, x2 = np.split(x_rotate, 2, axis=-1)

    # Calculate real and imaginary values
    real = (cos_cache * x1) - (sin_cache * x2)
    imag = (sin_cache * x1) + (cos_cache * x2)

    # Insert rotated embeddings back into the original input
    if interleaved:
        # x_rotate[:, :, :, 0::2] = real
        # x_rotate[:, :, :, 1::2] = imag
        real = np.expand_dims(real, axis=-1)
        imag = np.expand_dims(imag, axis=-1)
        x_rotate_concat = np.concatenate((real, imag), axis=-1)
        x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
    else:
        x_rotate = np.concatenate((real, imag), axis=-1)
    output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
    if len(original_input_shape) == 3:
        output = np.reshape(output, original_input_shape)
    else:
        output = np.transpose(output, (0, 2, 1, 3))
    return output
```
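
A small usage sketch of the reference function above; the cache construction follows the RoPE paper's inverse-frequency scheme and is an assumption for illustration only, since the operator simply consumes whatever `cos_cache`/`sin_cache` it is given:

```python
import numpy as np

# Toy shapes: batch=1, num_heads=2, seq_len=4, head_size=8 (so rotary_embedding_dim/2 = 4)
batch, num_heads, seq_len, head_size = 1, 2, 4, 8
x = np.random.randn(batch, num_heads, seq_len, head_size).astype(np.float32)

# Illustrative cache construction (assumed, not prescribed by the op)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))  # [head_size/2]
angles = np.arange(seq_len)[:, None] * inv_freq[None, :]                # [seq_len, head_size/2]
cos_cache = np.cos(angles).astype(np.float32)
sin_cache = np.sin(angles).astype(np.float32)

position_ids = np.arange(seq_len)[None, :].repeat(batch, axis=0)        # [batch, seq_len]
y = rotary_embedding(x, cos_cache, sin_cache, position_ids=position_ids)
print(y.shape)  # (1, 2, 4, 8): same shape as the input
```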

Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<23>`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>interleaved</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>num_heads</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>rotary_embedding_dim</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `X` | tensor of 32-bit float values or tensor of 16-bit float values or tensor of bfloat16 type values
| `cos_cache` | tensor of 32-bit float values or tensor of 16-bit float values or tensor of bfloat16 type values
| `sin_cache` | tensor of 32-bit float values or tensor of 16-bit float values or tensor of bfloat16 type values
| `position_ids` | tensor of 64-bit signless integer values or none type

#### Results:

| Result | Description |
| :----: | ----------- |
| `Y` | tensor of 32-bit float values or tensor of 16-bit float values or tensor of bfloat16 type values

### `onnx.Round` (ONNXRoundOp)

_ONNX Round operation_

src/Builder/OpBuildTable.inc

Lines changed: 6 additions & 0 deletions
@@ -18,6 +18,7 @@ op_dialect_version_map_["Asin"] = {22};
 op_dialect_version_map_["Asinh"] = {22};
 op_dialect_version_map_["Atan"] = {22};
 op_dialect_version_map_["Atanh"] = {22};
+op_dialect_version_map_["Attention"] = {23};
 op_dialect_version_map_["AveragePool"] = {22};
 op_dialect_version_map_["BatchNormalization"] = {15, 9};
 op_dialect_version_map_["Bernoulli"] = {22};
@@ -161,6 +162,7 @@ op_dialect_version_map_["Reshape"] = {21};
 op_dialect_version_map_["Resize"] = {19, 18, 13, 11, 10};
 op_dialect_version_map_["ReverseSequence"] = {10};
 op_dialect_version_map_["RoiAlign"] = {22};
+op_dialect_version_map_["RotaryEmbedding"] = {23};
 op_dialect_version_map_["Round"] = {22};
 op_dialect_version_map_["SVMClassifier"] = {1};
 op_dialect_version_map_["SVMRegressor"] = {1};
@@ -236,6 +238,8 @@ import_handler_map_["Atan"] =
     &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAtanOp>;
 import_handler_map_["Atanh"] =
     &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAtanhOp>;
+import_handler_map_["Attention"] =
+    &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAttentionOp>;
 import_handler_map_["AveragePool"] =
     &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAveragePoolOp>;
 import_handler_map_["BatchNormalization"] =
@@ -552,6 +556,8 @@ import_handler_map_["ReverseSequence"] =
     &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReverseSequenceOp>;
 import_handler_map_["RoiAlign"] =
     &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRoiAlignOp>;
+import_handler_map_["RotaryEmbedding"] =
+    &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRotaryEmbeddingOp>;
 import_handler_map_["Round"] =
     &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRoundOp>;
 import_handler_map_["STFT"] =

src/Dialect/ONNX/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -58,11 +58,13 @@ add_onnx_mlir_library(OMONNXOps
   ONNXOps/Math/Reduction.cpp
   ONNXOps/Math/Scatter.cpp
   ONNXOps/Math/TopK.cpp
+  ONNXOps/NN/Attention.cpp
   ONNXOps/NN/Conv.cpp
   ONNXOps/NN/Dropout.cpp
   ONNXOps/NN/Normalization.cpp
   ONNXOps/NN/Pooling.cpp
   ONNXOps/NN/RoiAlign.cpp
+  ONNXOps/NN/RotaryEmbedding.cpp
   ONNXOps/ObjectDetection/NonMaxSuppression.cpp
   ONNXOps/Quantize/DequantizeLinear.cpp
   ONNXOps/Quantize/DynamicQuantizeLinear.cpp
