
Commit 0409c63

[js/webgpu] Optimize MultiHeadAttention|Transpose (#22420)
### Description

With this optimization, the 96 MultiHeadAttention|Transpose ops in phi3 disappear, and phi3 improves from 107 tokens to 113 tokens on my dGPUs. The optimization skips the transpose op when one of the transposed dims is 1, because a reshape is sufficient in that case.
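For context (not part of the commit itself), here is a minimal TypeScript sketch of why the transpose can be skipped: permuting a [batchSize, sequenceLength, numHeads, headSize] tensor to [batchSize, numHeads, sequenceLength, headSize] moves no data when numHeads === 1 or sequenceLength === 1, because every element's row-major offset is unchanged, so only the shape metadata needs to change, which is exactly what reshape does. The helper names below (offsetBSNH, offsetBNSH) are illustrative, not from the ONNX Runtime source.

```ts
// Sketch: the BSNH -> BNSH transpose moves no data when numHeads === 1 or
// sequenceLength === 1, so a reshape (shape-metadata change only) suffices.
// Helper names are illustrative, not from the ONNX Runtime source.

// Row-major offset of element (b, s, n, h) in a [B, S, N, H] tensor.
const offsetBSNH = (b: number, s: number, n: number, h: number, S: number, N: number, H: number) =>
  ((b * S + s) * N + n) * H + h;

// Row-major offset of the same logical element after permuting to [B, N, S, H].
const offsetBNSH = (b: number, s: number, n: number, h: number, S: number, N: number, H: number) =>
  ((b * N + n) * S + s) * H + h;

// With S === 1 the s index is always 0 (symmetrically for N === 1), and both
// formulas collapse to (b * N + n) * H + h: every element already sits at its
// post-transpose offset, so the transpose program can be skipped.
const B = 2, S = 1, N = 32, H = 96;
for (let b = 0; b < B; b++) {
  for (let n = 0; n < N; n++) {
    for (let h = 0; h < H; h++) {
      console.assert(
        offsetBSNH(b, 0, n, h, S, N, H) === offsetBNSH(b, 0, n, h, S, N, H),
        'offsets must match when sequenceLength === 1',
      );
    }
  }
}
```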
1 parent de93f40 commit 0409c63

File tree

1 file changed: +6 −0 lines changed


js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts

Lines changed: 6 additions & 0 deletions
@@ -338,6 +338,9 @@ export const maybeTransposeToBNSHAndAddBias = (
     if (input.dims.length === 3) {
       reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
     }
+    if (numHeads === 1 || sequenceLength === 1) {
+      return reshapedInput;
+    }
     return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
       inputs: [reshapedInput],
       outputs: [-1],
@@ -356,6 +359,9 @@ export const maybeTransposeToBNSHAndAddBias = (
       biasOffset!,
     );
     reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
+    if (numHeads === 1 || sequenceLength === 1) {
+      return reshapedInput;
+    }
     return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
       inputs: [reshapedInput],
       outputs: [-1],
