Skip to content

Commit 21b3cbc

Browse files
[WIP][JS/WebGPU] Inputs Key and Value could be 4-dims. (#20470)
### Description The Key and Value inputs could be 4-dims ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 2c19db0 commit 21b3cbc

File tree

3 files changed

+239
-11
lines changed

3 files changed

+239
-11
lines changed

js/web/lib/wasm/jsep/webgpu/ops/attention.ts

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
282282
})()};
283283
workgroupBarrier();
284284
285-
var max_value = -3.402823e+38f;
285+
var max_value = f32(-3.402823e+38f);
286286
for (var i = 0u; i < ${WG}; i++) {
287287
max_value = max(thread_max[i], max_value);
288288
}
289289
290-
var sum_vector = ${f32Type}(${0});
290+
var sum_vector = ${f32Type}(0);
291291
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
292292
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
293293
}
@@ -378,7 +378,6 @@ const createAttentionProbsProgramInfo =
378378
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
379379
];
380380
return `
381-
const beta: ${dataType} = 1.0;
382381
const TILE_SIZE = ${TILE_SIZE}u;
383382
384383
var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
@@ -426,16 +425,16 @@ const createAttentionProbsProgramInfo =
426425
throw new Error(`Unsupported components: ${components}`);
427426
}
428427
})()};
429-
output[outputIdx] = sum * uniforms.alpha;
428+
430429
${(() => {
431430
if (relativePositionBiasInput) {
432431
return `
433432
let batch = workgroup_id.z / uniforms.num_heads;
434433
let head = workgroup_id.z % uniforms.num_heads;
435434
var indices = ${relativePositionBiasInput.type.indices}(batch, head, global_id.y, global_id.x);
436-
output[outputIdx] += ${relativePositionBiasInput.getByIndices('indices')};`;
435+
output[outputIdx] = sum * uniforms.alpha + ${relativePositionBiasInput.getByIndices('indices')};`;
437436
}
438-
return '';
437+
return 'output[outputIdx] = sum * uniforms.alpha;';
439438
})()}
440439
}
441440
}`;
@@ -512,7 +511,6 @@ const createVxAttentionScoreProgramInfo =
512511
// we need to transpose output from BNSH_v to BSND_v
513512
let batchIdx = workgroup_id.z / uniforms.num_heads;
514513
let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
515-
let headOffset = (batchIdx * uniforms.M * uniforms.num_heads + currentBatchHeadNumber) * uniforms.N;
516514
if (m < uniforms.M && n < uniforms.N) {
517515
let outputIdx = batchIdx * uniforms.M *uniforms.v_hidden_size + m * uniforms.v_hidden_size
518516
+ currentBatchHeadNumber * uniforms.N + n;

js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
339339

340340
if (kvBNSH) {
341341
return applyAttention(
342-
context, Q, key, value, keyPaddingMask, undefined, undefined, undefined, relativePositionBias, params,
342+
context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, relativePositionBias, params,
343343
attributes);
344344
}
345345
if (!key || !value) {

js/web/test/data/ops/multihead-attention.jsonc

Lines changed: 233 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@
604604
]
605605
},
606606
{
607-
"name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey and pastValue",
607+
"name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey, pastValue, presentKey and presentValue",
608608
"operator": "MultiHeadAttention",
609609
"opset": { "domain": "com.microsoft", "version": 1 },
610610
"attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
@@ -765,7 +765,83 @@
765765
]
766766
},
767767
{
768-
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey and PastValue",
768+
"name": "MultiHeadAttention Basic, one head and head-size one with RelativePositionBias, pastKey, pastValue, presentKey and presentValue",
769+
"operator": "MultiHeadAttention",
770+
"opset": { "domain": "com.microsoft", "version": 1 },
771+
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
772+
"cases": [
773+
{
774+
"name": "T[0]",
775+
"inputs": [
776+
// Q
777+
{
778+
"data": [1.0],
779+
"dims": [1, 1, 1],
780+
"type": "float32"
781+
},
782+
// K
783+
{
784+
"data": [2.0],
785+
"dims": [1, 1, 1],
786+
"type": "float32"
787+
},
788+
// V
789+
{
790+
"data": [3.0],
791+
"dims": [1, 1, 1],
792+
"type": "float32"
793+
},
794+
// Bias
795+
{
796+
"data": null,
797+
"type": "float32"
798+
},
799+
// Mask
800+
{
801+
"data": null,
802+
"type": "int32"
803+
},
804+
// RelativePositionBias
805+
{
806+
"data": [10, 20],
807+
"dims": [1, 1, 1, 2],
808+
"type": "float32"
809+
},
810+
// PastKey
811+
{
812+
"data": [4.0],
813+
"dims": [1, 1, 1, 1],
814+
"type": "float32"
815+
},
816+
// PastValue
817+
{
818+
"data": [5.0],
819+
"dims": [1, 1, 1, 1],
820+
"type": "float32"
821+
}
822+
],
823+
"outputs": [
824+
{
825+
"data": [3.0006706714630127],
826+
"dims": [1, 1, 1],
827+
"type": "float32"
828+
},
829+
{
830+
"data": [4, 2],
831+
"dims": [1, 1, 2, 1],
832+
"type": "float32"
833+
},
834+
{
835+
"data": [5, 3],
836+
"dims": [1, 1, 2, 1],
837+
"type": "float32"
838+
}
839+
]
840+
}
841+
]
842+
},
843+
{
844+
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue, PresentKey and PresentValue",
769845
"operator": "MultiHeadAttention",
770846
"opset": { "domain": "com.microsoft", "version": 1 },
771847
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@@ -803,7 +879,7 @@
803879
},
804880
// RelativePositionBias
805881
{
806-
"data": [10, 20],
882+
"data": [100, 200],
807883
"dims": [1, 1, 1, 2],
808884
"type": "float32"
809885
},
@@ -821,8 +897,162 @@
821897
}
822898
],
823899
"outputs": [
900+
{
901+
"data": [9, 10, 11, 12],
902+
"dims": [1, 1, 4],
903+
"type": "float32"
904+
},
905+
// Present key
906+
{
907+
"data": [13, 14, 15, 16, 5, 6, 7, 8],
908+
"dims": [1, 1, 2, 4],
909+
"type": "float32"
910+
},
911+
// Present value
912+
{
913+
"data": [17, 18, 19, 20, 9, 10, 11, 12],
914+
"dims": [1, 1, 2, 4],
915+
"type": "float32"
916+
}
917+
]
918+
}
919+
]
920+
},
921+
{
922+
"name": "MultiHeadAttention Basic, one head and head-size one with pastKey and pastValue; kvBNSH (4-dim Key and Value, 3-dim Q)",
923+
"operator": "MultiHeadAttention",
924+
"opset": { "domain": "com.microsoft", "version": 1 },
925+
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
926+
"cases": [
927+
{
928+
"name": "T[0]",
929+
"inputs": [
930+
// Q
931+
{
932+
"data": [1.0],
933+
"dims": [1, 1, 1],
934+
"type": "float32"
935+
},
936+
// K
937+
{
938+
"data": [2.0],
939+
"dims": [1, 1, 1, 1],
940+
"type": "float32"
941+
},
942+
// V
943+
{
944+
"data": [3.0],
945+
"dims": [1, 1, 1, 1],
946+
"type": "float32"
947+
},
948+
// Bias
949+
{
950+
"data": null,
951+
"type": "float32"
952+
},
953+
// Mask
954+
{
955+
"data": null,
956+
"type": "int32"
957+
},
958+
// RelativePositionBias
959+
{
960+
"data": [10, 20],
961+
"dims": [1, 1, 1, 2],
962+
"type": "float32"
963+
},
964+
// PastKey
965+
{
966+
"data": [4.0],
967+
"dims": [1, 1, 1, 1],
968+
"type": "float32"
969+
},
970+
// PastValue
971+
{
972+
"data": [5.0],
973+
"dims": [1, 1, 1, 1],
974+
"type": "float32"
975+
}
976+
],
977+
"outputs": [
978+
{
979+
"data": [3.0006706714630127],
980+
"dims": [1, 1, 1],
981+
"type": "float32"
982+
},
983+
{
984+
"data": [4, 2],
985+
"dims": [1, 1, 2, 1],
986+
"type": "float32"
987+
},
988+
{
989+
"data": [5, 3],
990+
"dims": [1, 1, 2, 1],
991+
"type": "float32"
992+
}
993+
]
994+
}
995+
]
996+
},
997+
{
998+
"name": "MultiHeadAttention Basic, one head and head-size 4 with pastKey and pastValue; Key and Value 4-dims",
999+
"operator": "MultiHeadAttention",
1000+
"opset": { "domain": "com.microsoft", "version": 1 },
1001+
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
1002+
"cases": [
1003+
{
1004+
"name": "T[0]",
1005+
"inputs": [
1006+
// Q
1007+
{
1008+
"data": [1, 2, 3, 4],
1009+
"dims": [1, 1, 4],
1010+
"type": "float32"
1011+
},
1012+
// K
1013+
{
1014+
"data": [5, 6, 7, 8],
1015+
"dims": [1, 1, 1, 4],
1016+
"type": "float32"
1017+
},
1018+
// V
1019+
{
1020+
"data": [9, 10, 11, 12],
1021+
"dims": [1, 1, 1, 4],
1022+
"type": "float32"
1023+
},
1024+
// Bias
1025+
{
1026+
"data": null,
1027+
"type": "float32"
1028+
},
1029+
// Mask
1030+
{
1031+
"data": null,
1032+
"type": "int32"
1033+
},
1034+
// RelativePositionBias
1035+
{
1036+
"data": [50, 100],
1037+
"dims": [1, 1, 1, 2],
1038+
"type": "float32"
1039+
},
1040+
// PastKey
1041+
{
1042+
"data": [13, 14, 15, 16],
1043+
"dims": [1, 1, 1, 4],
1044+
"type": "float32"
1045+
},
1046+
// PastValue
8241047
{
8251048
"data": [17, 18, 19, 20],
1049+
"dims": [1, 1, 1, 4],
1050+
"type": "float32"
1051+
}
1052+
],
1053+
"outputs": [
1054+
{
1055+
"data": [9.000362396240234, 10.00036334991455, 11.000362396240234, 12.000362396240234],
8261056
"dims": [1, 1, 4],
8271057
"type": "float32"
8281058
},

0 commit comments

Comments
 (0)