Skip to content

Commit 87e1f76

Browse files
[ONNX] Add Onnx->Torch lowering for GroupQueryAttention op (#4006)
This commit adds the lowering for Onnx's GroupQueryAttention op. The lowering is adopted from here: https://github.com/microsoft/onnxruntime/blob/65008cbb7393b121400a40dd8af4cc93d506918f/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L45 https://github.com/microsoft/onnxruntime/blob/65008cbb7393b121400a40dd8af4cc93d506918f/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h#L50 The reference implementation or pseudo-code can be viewed here: https://gist.github.com/vivekkhandelwal1/f307b687fb133f36276f3d1a3c60ed7e. The lowering supports the GQA with rotary_embedding. --------- Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 40b3469 commit 87e1f76

File tree

2 files changed

+513
-0
lines changed

2 files changed

+513
-0
lines changed

0 commit comments

Comments
 (0)