@@ -333,9 +333,9 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
 
 const createAttentionProbsProgramInfo =
     (_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined,
-     parameters: AttentionParameters, attributes: AttentionAttrs) => {
-      const probsShape =
-          [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.totalSequenceLength];
+     parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => {
+      const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
+      const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
 
       // TODO: handle mask
 
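The core of this hunk: the probs matrix's last dimension is now derived from the KV cache rather than read from `parameters.totalSequenceLength`. A minimal standalone sketch of that arithmetic (illustrative names, not the module's API):

```ts
// Illustrative only: pastSequenceLength counts cached tokens; kvSequenceLength
// counts the new key/value tokens fed in this step.
const totalSeqLen = (pastSequenceLength: number, kvSequenceLength: number): number =>
    pastSequenceLength + kvSequenceLength;

// Decoding one new token against a 7-token cache: probs has 8 columns.
console.log(totalSeqLen(7, 1));  // 8
```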
@@ -344,14 +344,13 @@ const createAttentionProbsProgramInfo =
       const vectorizedHeadSize = parameters.headSize / components;
       const TILE_SIZE = 12;
       const dispatch = {
-        x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
+        x: Math.ceil(totalSequenceLength / TILE_SIZE),
         y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
         z: parameters.batchSize * parameters.numHeads
       };
       const programUniforms: ProgramUniform[] = [
         {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
-        {type: DataType.uint32, data: parameters.totalSequenceLength},
-        {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.kvSequenceLength},
+        {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
         {type: q.dataType, data: alpha}
       ];
 
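For context, the dispatch above covers the `sequenceLength x totalSequenceLength` probs matrix in `TILE_SIZE x TILE_SIZE` workgroups, one z-slice per (batch, head) pair. A hedged sketch of the grid math (`probsDispatch` is a stand-in helper, not from the file):

```ts
const TILE_SIZE = 12;

// Hypothetical helper mirroring the dispatch computation in the hunk above.
const probsDispatch = (seqLen: number, totalSeqLen: number, batchSize: number, numHeads: number) => ({
  x: Math.ceil(totalSeqLen / TILE_SIZE),  // tiles across the key/total axis
  y: Math.ceil(seqLen / TILE_SIZE),       // tiles across the query axis
  z: batchSize * numHeads,                // one slice per batch-head pair
});

// 8 key columns and 1 query row fit in a single tile per head:
console.log(probsDispatch(1, 8, 1, 12));  // { x: 1, y: 1, z: 12 }
```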
@@ -376,8 +375,7 @@ const createAttentionProbsProgramInfo =
 
       const uniforms: UniformsArrayType = [
         {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
-        {name: 'num_heads', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'},
-        {name: 'alpha', type: dataType as UniformDataElementType}
+        {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
       ];
       return `
   const beta: ${dataType} = 1.0;
@@ -394,7 +392,7 @@ const createAttentionProbsProgramInfo =
     let m = workgroup_id.y * TILE_SIZE;
     let n = workgroup_id.x * TILE_SIZE;
     let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
-    let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K;
+    let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;
 
     var value = ${qInput.type.value}(0);
     for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
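The `kOffset` fix is the subtle part: once past keys are concatenated into `key`, each head owns `N` (= totalSequenceLength) rows, so striding by `kv_sequence_length` would start later heads at the wrong element. A plain-TypeScript sketch of the offset arithmetic, assuming the same row-major `[head, row, K]` layout as the WGSL above (`kOffset` here is an illustrative function, not an export):

```ts
// Element offsets into a row-major [numHeads, N, K] key buffer.
// N is rows per head (total sequence length), K is the head size.
const kOffset = (headIdx: number, n: number, N: number, K: number): number =>
    N * K * headIdx + n * K;

// Head 1 with N = 8 rows of K = 64 starts at element 512, not 64 (which is
// what a stride of kv_sequence_length = 1 would produce mid-decode):
console.log(kOffset(1, 0, 8, 64));  // 512
console.log(1 * 64 * 1 + 0 * 64);   // 64 — the old, wrong stride
```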
@@ -456,7 +454,9 @@ const createAttentionProbsProgramInfo =
 
 
 const createVxAttentionScoreProgramInfo =
-    (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
+    (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
+     pastSequenceLength: number) => {
+      const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
       const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
       const TILE_SIZE = 12;
       const dispatch = {
@@ -465,7 +465,7 @@ const createVxAttentionScoreProgramInfo =
         z: params.batchSize * params.numHeads
       };
       const programUniforms: ProgramUniform[] = [
-        {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
+        {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
         {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
         {type: DataType.uint32, data: params.vHiddenSize}
       ];
@@ -537,24 +537,25 @@ export const applyAttention =
     (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
      _past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
      relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
+      const outputPresentKey = context.outputCount > 1;
+      const outputPresentValue = context.outputCount > 2;
+      const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
+      const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
       // Concatenate pastKey and K to produce presentKey.
-      const presentKeyShape =
-          [parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
+      const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
       const concatKeyInputs = pastKey ? [pastKey, k] : [k];
-      const key = (context.outputCount > 1 || pastKey) ?
-          context.compute(
-              createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
-              {inputs: concatKeyInputs, outputs: [context.outputCount > 1 ? 1 : -1]})[0] :
-          k;
+      const key = outputPresentKey ? context.compute(
+                                         createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
+                                         {inputs: concatKeyInputs, outputs: [1]})[0] :
+                                     k;
 
       // Concatenate pastValue and V to produce presentValue.
-      const presentValueShape =
-          [parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
+      const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
       const concatValueInputs = pastValue ? [pastValue, v] : [v];
-      const value = (context.outputCount > 2 || pastValue) ?
+      const value = outputPresentValue ?
           context.compute(
               createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
-              {inputs: concatValueInputs, outputs: [context.outputCount > 2 ? 2 : -1]})[0] :
+              {inputs: concatValueInputs, outputs: [2]})[0] :
           v;
       const inputsK = [q, key];
       if (relativePositionBias) {
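To summarize the cache handling above: past and new keys/values are concatenated along the sequence axis (dim 2) only when the graph requests the present outputs; otherwise the fresh `k`/`v` are used directly. A condensed shape-level sketch (stubbed types; nothing here is the real `TensorView` or `context.compute` API):

```ts
// Shapes only; the real code runs a concat program on the GPU.
type Shape = number[];

const presentShape = (batchSize: number, numHeads: number, totalSeqLen: number, headSize: number): Shape =>
    [batchSize, numHeads, totalSeqLen, headSize];

// Join a cached tensor and a fresh tensor along the sequence axis (dim 2).
const concatSeqDim = (past: Shape | undefined, fresh: Shape): Shape =>
    past ? [fresh[0], fresh[1], past[2] + fresh[2], fresh[3]] : fresh;

// 7 cached tokens + 1 new token, 12 heads of size 64:
console.log(concatSeqDim([1, 12, 7, 64], [1, 12, 1, 64]));  // [1, 12, 8, 64]
console.log(presentShape(1, 12, 8, 64));                    // [1, 12, 8, 64]
```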
@@ -563,20 +564,22 @@ export const applyAttention =
 
       // Run AttentionProbs
       const probs = context.compute(
-          createAttentionProbsProgramInfo(context, q, key, relativePositionBias, parameters, attributes),
+          createAttentionProbsProgramInfo(
+              context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength),
           {inputs: inputsK, outputs: [-1]})[0];
 
       // Run Softmax
       context.compute(
           createInPlaceSoftmaxProgramInfo(
               context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
-              parameters.totalSequenceLength),
+              totalSequenceLength),
           {inputs: [probs], outputs: []});
 
       // Run AttentionScore
       const inputsV = [probs, value];
       context.compute(
-          createVxAttentionScoreProgramInfo(context, probs, value, parameters), {inputs: inputsV, outputs: [0]});
+          createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength),
+          {inputs: inputsV, outputs: [0]});
     };
 
 const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
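Taken together, `applyAttention` now threads `pastSequenceLength` through all three stages. A shape-level sketch of the pipeline (illustrative helper, not one of the module's exports):

```ts
// probs = softmax(Q x K^T * alpha), then output = probs x V.
// B = batch, H = heads, S = query length, T = past + new KV length.
const attentionShapes = (B: number, H: number, S: number, past: number, kvLen: number, headSize: number) => {
  const T = past + kvLen;
  return {
    probs: [B, H, S, T],           // createAttentionProbsProgramInfo output
    softmax: [B * H * S, T],       // rows x columns fed to the in-place softmax
    output: [B, S, H * headSize],  // createVxAttentionScoreProgramInfo output
  };
};

console.log(attentionShapes(1, 12, 1, 7, 1, 64));
// { probs: [1, 12, 1, 8], softmax: [12, 8], output: [1, 1, 768] }
```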