| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
### `onnx.Attention` (ONNXAttentionOp)
_ONNX Attention operation_
Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.
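For orientation, the core computation can be sketched in NumPy as follows. This is only an illustrative sketch: the 4D `[batch, num_heads, sequence_length, head_size]` layout and the `1/sqrt(head_size)` default scale are assumptions made for the example, not a restatement of the operator's normative definition.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, attn_bias=None, scale=None):
    # Q: [batch, num_heads, q_seq, head_size]; K, V: [batch, num_heads, kv_seq, head_size].
    if scale is None:
        scale = 1.0 / np.sqrt(Q.shape[-1])            # assumed default scaling
    scores = (Q @ np.swapaxes(K, -1, -2)) * scale     # [batch, num_heads, q_seq, kv_seq]
    if attn_bias is not None:
        scores = scores + attn_bias                   # optional additive mask/bias
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over kv_seq
    return weights @ V                                # [batch, num_heads, q_seq, head_size]
```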
This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V.
For self attention, `kv_sequence_length` equals `q_sequence_length`.
For cross attention, query and key might have different lengths.
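The following NumPy sketch illustrates the difference in shapes; the 4D `[batch, num_heads, sequence_length, head_size]` layout and the concrete sizes are assumptions chosen only for this example.

```python
import numpy as np

# Assumed layout: [batch, num_heads, sequence_length, head_size].
batch, num_heads, head_size = 2, 8, 64

# Self attention: Q, K and V are derived from the same sequence,
# so kv_sequence_length == q_sequence_length.
Q_self = np.random.randn(batch, num_heads, 10, head_size).astype(np.float32)
K_self = np.random.randn(batch, num_heads, 10, head_size).astype(np.float32)

# Cross attention: K and V come from a different sequence (e.g. an encoder output),
# so kv_sequence_length may differ from q_sequence_length.
Q_cross = np.random.randn(batch, num_heads, 10, head_size).astype(np.float32)
K_cross = np.random.randn(batch, num_heads, 25, head_size).astype(np.float32)
```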
This operator also covers the following three variants based on the number of heads (see the sketch after this list):
1) Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`.
2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
3) Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`.
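A hedged sketch of how these head-count relations can be realized: grouped or shared K/V heads are expanded so that every query head has a matching key/value head. The shapes and the use of `np.repeat` are illustrative assumptions, not the operator's prescribed implementation.

```python
import numpy as np

# Assumed layout: [batch, num_heads, sequence_length, head_size].
batch, seq_len, head_size = 1, 16, 64
q_num_heads, kv_num_heads = 8, 2          # GQA: q_num_heads % kv_num_heads == 0

K = np.random.randn(batch, kv_num_heads, seq_len, head_size).astype(np.float32)
V = np.random.randn(batch, kv_num_heads, seq_len, head_size).astype(np.float32)

group_size = q_num_heads // kv_num_heads       # 4 query heads share each K/V head
K_expanded = np.repeat(K, group_size, axis=1)  # (1, 8, 16, 64)
V_expanded = np.repeat(V, group_size, axis=1)  # (1, 8, 16, 64)

# MHA is the case group_size == 1; MQA is the case kv_num_heads == 1,
# where the single K/V head is shared by all query heads.
```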
The attention bias to be added is calculated based on the `attn_mask` input and the `is_causal` attribute, only one of which can be provided (see the sketch after this list):
1) If `is_causal` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment.
2) `attn_mask`: A boolean mask, where a value of `True` indicates that the element should take part in attention, or a float mask of the same type as query, key and value that is added to the attention score.
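A minimal sketch of how these two mutually exclusive options translate into an additive bias on the attention scores. The helper name and the upper-left alignment of the causal mask follow the description above, but the code itself is an illustrative assumption, not the reference implementation.

```python
import numpy as np

def make_attn_bias(q_len, kv_len, attn_mask=None, is_causal=False, dtype=np.float32):
    # Hypothetical helper: returns a bias that is added to Q*K^T before the softmax.
    bias = np.zeros((q_len, kv_len), dtype=dtype)
    if is_causal:
        # Lower-triangular (upper-left aligned) causal mask: positions above the
        # diagonal are disallowed and receive -inf.
        disallowed = np.triu(np.ones((q_len, kv_len), dtype=bool), k=1)
        bias[disallowed] = -np.inf
    elif attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            # Boolean mask: True means "attend"; False positions receive -inf.
            bias = np.where(attn_mask, 0.0, -np.inf).astype(dtype)
        else:
            # Float mask: added to the attention scores as-is.
            bias = attn_mask.astype(dtype)
    return bias
```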
Both past and present state key/values are optional. They must be used together; providing only one of them is not allowed.
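For illustration, the relation between past and present states might look like the following sketch, assuming a 4D `[batch, kv_num_heads, sequence_length, head_size]` layout and concatenation along the sequence axis; the concrete shapes are assumptions made only for this example.

```python
import numpy as np

# Assumed layout: [batch, kv_num_heads, sequence_length, head_size].
past_key = np.zeros((2, 4, 16, 64), dtype=np.float32)      # cached keys from earlier steps
new_key = np.random.randn(2, 4, 1, 64).astype(np.float32)  # keys for the current step

# The present state extends the past state with the newly computed keys.
present_key = np.concatenate([past_key, new_key], axis=2)  # shape (2, 4, 17, 64)
```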
The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided:
```
The following pattern is applied by this operator:
```

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>softmax_precision</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `Q` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `K` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `V` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `attn_mask` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 1-bit signless integer values or none type
| `past_key` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type
| `past_value` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type
#### Results:
| Result | Description |
| :----: | ----------- |
| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `present_key` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type
| `present_value` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type
| `qk_matmul_output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type
### `onnx.RotaryEmbedding` (ONNXRotaryEmbeddingOp)

_ONNX RotaryEmbedding operation_

RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864.
The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances
between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token's absolute position (position_ids).
The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles.
For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the
embedding vector either into two halves or by interleaving alternate elements, and applying the rotation matrix to each half of the embedding vector.
The rotation matrix is parameterized by the token's position in the sequence. The rotated halves of the embedding vector are concatenated
to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism.
The rotation ensures that the model captures both absolute and relative positional information.
Rotary embeddings are defined using the following algorithm:
```python
def rotary_embedding(
    input: np.ndarray,
    cos_cache: np.ndarray,
    sin_cache: np.ndarray,
    position_ids: np.ndarray | None = None,
    interleaved=None,
    rotary_embedding_dim=None,
    num_heads=None,
) -> np.ndarray:
    original_input_shape = input.shape
    # First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size]
```
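As a complementary illustration of the split-halves (non-interleaved) rotation described above, here is a minimal, self-contained sketch. It is an assumption-laden simplification: it ignores `position_ids` lookup, interleaved mode and partial rotation (`rotary_embedding_dim`), and the pair-angle convention shown is one common choice rather than the normative reference.

```python
import numpy as np

def rope_split_halves(x, cos, sin):
    # Rotate the last dimension of x using the split-halves scheme:
    # the vector is split into two halves, and each pair (x1_i, x2_i) is
    # rotated by the angle whose per-pair cosine/sine are given.
    x1, x2 = np.split(x, 2, axis=-1)
    rotated_first = x1 * cos - x2 * sin
    rotated_second = x1 * sin + x2 * cos
    return np.concatenate([rotated_first, rotated_second], axis=-1)

# Tiny example: one token at position 3 with head_size = 4.
head_size, pos = 4, 3
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))  # one frequency per pair
angles = pos * inv_freq                      # rotation angle depends on the absolute position
x = np.arange(head_size, dtype=np.float32)
rotated = rope_split_halves(x, np.cos(angles), np.sin(angles))
```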