
Commit ae78cdb

[JS/WebGPU] MultiheadAttention bugfix (#20447)
### Description

Fixed the condition for concatenating pastKey with key and pastValue with value, fixed an index error in the attention-probs shader, and added new test cases.
1 parent 33d5ea3 commit ae78cdb
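For readers skimming the diff below, the core of the fix is that concatenation of pastKey/pastValue with key/value, and the resulting total sequence length, is now gated on whether the node actually requests the present_key/present_value outputs (via `context.outputCount`), rather than on the mere presence of past inputs. The following standalone TypeScript sketch mirrors that decision logic; the helper name `resolveSequenceLengths` and the parameter interface are hypothetical and exist only for illustration, not in the actual change:

```ts
// Illustrative only: a standalone helper mirroring the decision logic introduced by
// this commit (not part of the ONNX Runtime source).
interface AttentionShapeParams {
  pastSequenceLength: number;  // length of pastKey/pastValue, 0 if absent
  kvSequenceLength: number;    // length of the incoming key/value
}

function resolveSequenceLengths(outputCount: number, params: AttentionShapeParams) {
  const outputPresentKey = outputCount > 1;    // output index 1 is present_key
  const outputPresentValue = outputCount > 2;  // output index 2 is present_value
  // Past state only contributes when the present key/value outputs are requested;
  // otherwise the effective past length is 0 and no concatenation happens.
  const pastSequenceLength = outputPresentKey && outputPresentValue ? params.pastSequenceLength : 0;
  const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
  return {outputPresentKey, outputPresentValue, pastSequenceLength, totalSequenceLength};
}

// Example: a decoding step with a KV cache of length 7 and one new token.
console.log(resolveSequenceLengths(3, {pastSequenceLength: 7, kvSequenceLength: 1}));
// -> { outputPresentKey: true, outputPresentValue: true, pastSequenceLength: 7, totalSequenceLength: 8 }
```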

2 files changed: +679 −25 lines changed


js/web/lib/wasm/jsep/webgpu/ops/attention.ts

Lines changed: 28 additions & 25 deletions
@@ -333,9 +333,9 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
 
 const createAttentionProbsProgramInfo =
     (_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined,
-     parameters: AttentionParameters, attributes: AttentionAttrs) => {
-      const probsShape =
-          [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.totalSequenceLength];
+     parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => {
+      const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
+      const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
 
       // TODO: handle mask
 
@@ -344,14 +344,13 @@ const createAttentionProbsProgramInfo =
       const vectorizedHeadSize = parameters.headSize / components;
       const TILE_SIZE = 12;
       const dispatch = {
-        x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
+        x: Math.ceil(totalSequenceLength / TILE_SIZE),
         y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
         z: parameters.batchSize * parameters.numHeads
       };
       const programUniforms: ProgramUniform[] = [
         {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
-        {type: DataType.uint32, data: parameters.totalSequenceLength},
-        {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.kvSequenceLength},
+        {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
         {type: q.dataType, data: alpha}
       ];
 
@@ -376,8 +375,7 @@ const createAttentionProbsProgramInfo =
 
       const uniforms: UniformsArrayType = [
         {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
-        {name: 'num_heads', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'},
-        {name: 'alpha', type: dataType as UniformDataElementType}
+        {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
       ];
       return `
   const beta: ${dataType} = 1.0;
@@ -394,7 +392,7 @@ const createAttentionProbsProgramInfo =
     let m = workgroup_id.y * TILE_SIZE;
     let n = workgroup_id.x * TILE_SIZE;
     let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
-    let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K;
+    let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;
 
     var value = ${qInput.type.value}(0);
     for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@@ -456,7 +454,9 @@ const createAttentionProbsProgramInfo =
 
 
 const createVxAttentionScoreProgramInfo =
-    (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
+    (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
+     pastSequenceLength: number) => {
+      const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
       const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
       const TILE_SIZE = 12;
       const dispatch = {
@@ -465,7 +465,7 @@ const createVxAttentionScoreProgramInfo =
         z: params.batchSize * params.numHeads
       };
       const programUniforms: ProgramUniform[] = [
-        {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
+        {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
         {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
         {type: DataType.uint32, data: params.vHiddenSize}
       ];
@@ -537,24 +537,25 @@ export const applyAttention =
     (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
      _past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
      relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
+      const outputPresentKey = context.outputCount > 1;
+      const outputPresentValue = context.outputCount > 2;
+      const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
+      const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
       // Concatinate pastKey and K to produce presentKey.
-      const presentKeyShape =
-          [parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
+      const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
       const concatKeyInputs = pastKey ? [pastKey, k] : [k];
-      const key = (context.outputCount > 1 || pastKey) ?
-          context.compute(
-              createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
-              {inputs: concatKeyInputs, outputs: [context.outputCount > 1 ? 1 : -1]})[0] :
-          k;
+      const key = outputPresentKey ? context.compute(
+          createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
+          {inputs: concatKeyInputs, outputs: [1]})[0] :
+          k;
 
       // Concatinate pastValue and V to produce presentValue.
-      const presentValueShape =
-          [parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
+      const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
       const concatValueInputs = pastValue ? [pastValue, v] : [v];
-      const value = (context.outputCount > 2 || pastValue) ?
+      const value = outputPresentValue ?
           context.compute(
               createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
-              {inputs: concatValueInputs, outputs: [context.outputCount > 2 ? 2 : -1]})[0] :
+              {inputs: concatValueInputs, outputs: [2]})[0] :
           v;
       const inputsK = [q, key];
       if (relativePositionBias) {
@@ -563,20 +564,22 @@ export const applyAttention =
 
       // Run AttentionProbs
       const probs = context.compute(
-          createAttentionProbsProgramInfo(context, q, key, relativePositionBias, parameters, attributes),
+          createAttentionProbsProgramInfo(
+              context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength),
          {inputs: inputsK, outputs: [-1]})[0];
 
       // Run Softmax
       context.compute(
           createInPlaceSoftmaxProgramInfo(
               context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
-              parameters.totalSequenceLength),
+              totalSequenceLength),
           {inputs: [probs], outputs: []});
 
       // Run AttrionScore
       const inputsV = [probs, value];
       context.compute(
-          createVxAttentionScoreProgramInfo(context, probs, value, parameters), {inputs: inputsV, outputs: [0]});
+          createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength),
+          {inputs: inputsV, outputs: [0]});
     };
 
 const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
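The other half of the fix is the index error in the attention-probs shader: once pastKey has been concatenated with key, each head holds N (the total sequence length) rows of K elements, so the per-head offset must be derived from uniforms.N rather than the no-longer-passed kv_sequence_length. A minimal standalone TypeScript sketch of that offset arithmetic follows; the names `headKeyOffset`, `totalSequenceLength`, `headSize`, and `tileRow` are hypothetical, chosen only to illustrate the stride change (the real computation happens in WGSL inside attention.ts):

```ts
// Illustration of the per-head key offset used by the shader, written as plain TypeScript.
// All names here are hypothetical; this is not the ONNX Runtime implementation.
function headKeyOffset(headIdx: number, totalSequenceLength: number, headSize: number, tileRow: number): number {
  // After past and current key are concatenated, each head owns totalSequenceLength rows
  // of headSize elements, so the head stride is N * K rather than kvSequenceLength * K.
  return totalSequenceLength * headSize * headIdx + tileRow * headSize;
}

// Example: head 2, a cache of 7 past tokens plus 1 new token (N = 8), headSize 64, tile row 0.
console.log(headKeyOffset(2, 8, 64, 0));  // 1024
```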
