GroupQueryAttention strange behaviour depending on seq_len/total_seq_len #22732
Unanswered · ManelSemidynamics asked this question in API Q&A · 1 comment, 1 reply
**Reply (to the question below):** I think this behaviour is due to the CPU provider using a shared buffer for past_{key/value} and present_{key/value}. Is there any way to disable it, so that it prints the proper tensor as output?
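The shared-buffer effect described in this reply can be illustrated with a toy NumPy sketch (pure NumPy, not ONNX Runtime internals; the names are made up for illustration):

```python
import numpy as np

# Toy illustration of buffer aliasing: when "present" is a view into the
# same allocation as "past", writing present overwrites what you later
# print as past.
cache = np.zeros(4, dtype=np.float16)      # capacity = total_seq_len
past = cache[:2]                           # "past" aliases the cache
past[:] = [1.0, 1.0]
present = cache[:4]                        # "present" aliases the same memory
present[:] = [9.0, 9.0, 9.0, 9.0]          # kernel writes the full present tensor
print(past)                                # prints [9. 9.], not [1. 1.]
```

If the runtime does something like this internally, the tensor printed as `past_key` after the run no longer reflects the values that were fed in.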
**Question:** I am implementing my own version of GQA, and I found that if seq_len != 1 and total_seq_len != 2, the starting values of the "present_key" output are not the same as "past_key". Why is that? I expected present_key to be the concatenation (taking tensor formats into account) of "past_key" and "key" (with RoPE possibly applied to "key" before concatenation). (I found that seq_len=1, total_seq_len=2 gives the output I expect.)
Code to reproduce:
```python
import onnx
from onnx import helper, TensorProto
from collections import OrderedDict

kv_nh = 2
nh = 4
max_seq_len = 32

inputs = OrderedDict([
    ('query', helper.make_tensor_value_info('query', TensorProto.FLOAT16, ['batch_size', 'sequence_length', nh * 128])),
    ('key', helper.make_tensor_value_info('key', TensorProto.FLOAT16, ['batch_size', 'sequence_length', kv_nh * 128])),
    ('value', helper.make_tensor_value_info('value', TensorProto.FLOAT16, ['batch_size', 'sequence_length', kv_nh * 128])),
    ('past_key', helper.make_tensor_value_info('past_key', TensorProto.FLOAT16, ['batch_size', kv_nh, 'past_sequence_length', 128])),
    ('past_value', helper.make_tensor_value_info('past_value', TensorProto.FLOAT16, ['batch_size', kv_nh, 'past_sequence_length', 128])),
    ('seqlens_k', helper.make_tensor_value_info('seqlens_k', TensorProto.INT32, ['batch_size', 1])),
    ('total_sequence_length', helper.make_tensor_value_info('total_sequence_length', TensorProto.INT32, [])),
    ('cos_cache', helper.make_tensor_value_info('cos_cache', TensorProto.FLOAT16, [max_seq_len, 64])),
    ('sin_cache', helper.make_tensor_value_info('sin_cache', TensorProto.FLOAT16, [max_seq_len, 64]))
])
outputs = OrderedDict([
    ('output', helper.make_tensor_value_info('output', TensorProto.FLOAT16, ['batch_size', 'sequence_length', nh * 128])),
    ('present_key', helper.make_tensor_value_info('present_key', TensorProto.FLOAT16, ['batch_size', kv_nh, 'total_sequence_length', 128])),
    ('present_value', helper.make_tensor_value_info('present_value', TensorProto.FLOAT16, ['batch_size', kv_nh, 'total_sequence_length', 128]))
])

node = helper.make_node(
    'GroupQueryAttention',
    inputs=list(inputs.keys()),
    outputs=list(outputs.keys()),
    name='GroupQueryAttention_Node',
    domain='com.microsoft',
    do_rotary=1,
    kv_num_heads=kv_nh,
    num_heads=nh,
    rotary_interleaved=1,
    scale=0.0888383461536163
)
graph = helper.make_graph(
    [node],
    'GroupQueryAttentionGraph',
    list(inputs.values()),
    list(outputs.values())
)
model = helper.make_model(
    graph,
    producer_name='onnx-helper',
    opset_imports=[
        helper.make_opsetid("", 19),
        helper.make_opsetid("com.microsoft", 1)
    ]
)
onnx.checker.check_model(model)
onnx.save(model, 'group_query_attention.onnx')

import onnxruntime as ort
import numpy as np

session = ort.InferenceSession('group_query_attention.onnx')

batch = 1
seq_len = 2
tot_seq_len = 4
past = tot_seq_len - seq_len

query = np.random.rand(batch, seq_len, nh * 128).astype(np.float16)
key = np.random.rand(batch, seq_len, kv_nh * 128).astype(np.float16)
value = np.random.rand(batch, seq_len, kv_nh * 128).astype(np.float16)
past_key = np.random.rand(batch, kv_nh, past, 128).astype(np.float16)
past_value = np.random.rand(batch, kv_nh, past, 128).astype(np.float16)
seqlens_k = np.array([[seq_len]], dtype=np.int32)
total_sequence_length = np.array(tot_seq_len, dtype=np.int32)
cos_cache = np.random.rand(max_seq_len, 64).astype(np.float16)
sin_cache = np.random.rand(max_seq_len, 64).astype(np.float16)

inputs = {
    'query': query,
    'key': key,
    'value': value,
    'past_key': past_key,
    'past_value': past_value,
    'seqlens_k': seqlens_k,
    'total_sequence_length': total_sequence_length,
    'cos_cache': cos_cache,
    'sin_cache': sin_cache,
}
outputs = session.run(None, inputs)
print("inputs")
print(inputs)
print("outputs")
print(outputs)
```
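The expectation stated in the question can be written out as a NumPy reference (a sketch, assuming interleaved RoPE is applied to the new key tokens at their absolute positions; this is not ONNX Runtime's actual kernel, and the cos/sin convention is an assumption taken from the repro's `rotary_interleaved=1` setting):

```python
import numpy as np

# Expected present_key = concat(past_key, RoPE(key)) along the sequence axis,
# per KV head, with interleaved rotary pairs (x[2i], x[2i+1]).
batch, kv_nh, seq_len, past, head = 1, 2, 2, 2, 128
key = np.random.rand(batch, seq_len, kv_nh * head).astype(np.float16)
past_key = np.random.rand(batch, kv_nh, past, head).astype(np.float16)
cos_cache = np.random.rand(32, head // 2).astype(np.float16)
sin_cache = np.random.rand(32, head // 2).astype(np.float16)

# (batch, seq, kv_nh*head) -> (batch, kv_nh, seq, head)
k = key.reshape(batch, seq_len, kv_nh, head).transpose(0, 2, 1, 3).astype(np.float32)

pos = np.arange(past, past + seq_len)      # absolute positions of the new tokens
cos = cos_cache[pos].astype(np.float32)    # (seq_len, head // 2)
sin = sin_cache[pos].astype(np.float32)

x0, x1 = k[..., 0::2], k[..., 1::2]        # interleaved pairs
rot = np.empty_like(k)
rot[..., 0::2] = x0 * cos - x1 * sin
rot[..., 1::2] = x0 * sin + x1 * cos

expected_present_key = np.concatenate([past_key.astype(np.float32), rot], axis=2)
print(expected_present_key.shape)          # (1, 2, 4, 128)
```

Under this reading, `expected_present_key[:, :, :past, :]` is exactly `past_key`, which is what the question compares against the runtime's `present_key` output.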