SparseAttention ONNX Contrib Op Implementation #4275
base: develop
Conversation
This build is not recommended to merge 🔴

❌ bert-mrpc-tf: ERROR - check error output

2025-09-03 10:20:56.197188: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 359, in <module>
    main()
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 306, in main
    graph = load_tf_graph(model_name)
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 300, in load_tf_graph
    graph_def.ParseFromString(f.read())
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/lib/io/file_io.py", line 116, in read
    self._preread_check()
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/lib/io/file_io.py", line 77, in _preread_check
    self._read_buf = _pywrap_file_io.BufferedInputStream(
tensorflow.python.framework.errors_impl.UnimplementedError: File system scheme '[local]' not implemented (file: '/new-saved-models/tf-misc/bert_mrpc1.pb')

🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
🔴 mask-rcnn: FAILED: MIGraphX is not within tolerance - check verbose output
updates);
}

instruction_ref make_block_masks(module& mod,
An issue arises when applying the block mask.
The output of the first GEMM has shape BNSM, where:
B = batch size
N = num. heads
S = sequence length
M = max cache sequence length
The block mask, once unpacked and expanded, will have shape BNXX, where:
X = max_blocks * sparse_block_size
In cases where X != S and/or X != M, the block mask needs to be trimmed down to BNSM dims so that it can be applied to the GEMM output with a where.
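For concreteness, a minimal numpy sketch of the shape bookkeeping described above (all sizes are made up for illustration; the correct start offset on axis 2 is case-dependent, here it is simply 0):

```python
import numpy as np

# Made-up sizes for illustration.
B, N = 2, 4                          # batch size, num. heads
S, M = 3, 8                          # sequence length, max cache sequence length
max_blocks, sparse_block_size = 2, 8
X = max_blocks * sparse_block_size   # 16; dim of the unpacked/expanded mask

gemm_out = np.ones((B, N, S, M))                 # first GEMM output, BNSM
block_mask = np.zeros((B, N, X, X), dtype=bool)  # expanded block mask, BNXX
block_mask[..., ::2] = True                      # arbitrary mask pattern

# X != S and X != M, so trim the mask to BNSM before applying it with a
# where (masked-out positions are filled with -inf, as before a softmax).
trimmed = block_mask[:, :, :S, :M]
masked = np.where(trimmed, gemm_out, -np.inf)
assert trimmed.shape == (B, N, S, M)
assert masked.shape == gemm_out.shape
```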
The particular case that causes the issue is S = 1.
When the sequence length is one, the block mask needs to be sliced down to size 1 on axis 2, i.e. sliced from some start index k to k + 1. But what should k be?
This detail is not documented, but the implementation tells us it should be past_sequence_length.
How is past_sequence_length obtained?
The operator has an input called key_total_sequence_lengths, which is described as:
"1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings."
past_sequence_length is obtained by subtracting the sequence length from key_total_sequence_lengths.
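A small numpy sketch of that arithmetic, with made-up per-batch values (key_total_sequence_lengths stands in for the runtime operator input described above):

```python
import numpy as np

S = 1                                          # decode step: sequence length 1
key_total_sequence_lengths = np.array([5, 9])  # per-batch, runtime values

# past_sequence_length = total key length minus the current sequence length.
past_sequence_length = key_total_sequence_lengths - S

# For S == 1 the BNXX mask must be sliced to a single row on axis 2,
# from past_sequence_length to past_sequence_length + 1, per batch.
# The start index is a runtime value, which is what makes the slice dynamic.
B, N, X = 2, 4, 16
block_mask = np.arange(B * N * X * X).reshape(B, N, X, X)
rows = np.stack([block_mask[b, :, p:p + 1, :]
                 for b, p in enumerate(past_sequence_length)])
assert rows.shape == (B, N, 1, X)
```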
As a consequence, the slice start and end depend on a runtime value, making the slice dynamic.
Not sure how to circumvent this.
@TedThemistokleous