Skip to content

Commit 8159723

Browse files
authored
[js/webgpu] Optimize matmulnbits (#22360)
### Description <!-- Describe your changes. --> This PR further optimizes matmulnbits specially for iGPUs. The phi3 demo becomes ~12 tokens/second from ~8 tokens on iGPUs. Some todos: 1. Make the optimization more general, Remove the blockSize = 32 limitation. 2. Tune the parameter, such as workgroupSize, components size (currently only support components = 1), to see the performance change.
1 parent 2bc3754 commit 8159723

File tree

3 files changed

+179
-3
lines changed

3 files changed

+179
-3
lines changed

js/web/lib/wasm/jsep/webgpu/ops/common.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,7 @@ class ShaderHelperImpl implements ShaderHelper {
868868
const paramList = is1DimensionDispatch
869869
? `@builtin(global_invocation_id) global_id : vec3<u32>,
870870
@builtin(workgroup_id) workgroup_id : vec3<u32>,
871+
@builtin(local_invocation_index) local_idx : u32,
871872
@builtin(local_invocation_id) local_id : vec3<u32>`
872873
: `@builtin(global_invocation_id) global_id : vec3<u32>,
873874
@builtin(local_invocation_id) local_id : vec3<u32>,
@@ -876,7 +877,6 @@ class ShaderHelperImpl implements ShaderHelper {
876877
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
877878
const globalIdxDefinition = is1DimensionDispatch
878879
? `let global_idx = global_id.x;
879-
let local_idx = local_id.x;
880880
let workgroup_index = workgroup_id.x;`
881881
: `let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
882882
workgroup_id.y * num_workgroups[0] + workgroup_id.x;

js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,185 @@ export const createMatMulNBitsProgramInfo = (
266266
};
267267
};
268268

269+
// Currently, only support blockSize = 32.
270+
export const createMatMulNBitsBlockSize32ProgramInfo = (
271+
inputs: readonly TensorView[],
272+
attributes: MatMulNBitsAttributes,
273+
): ProgramInfo => {
274+
const inputShape = inputs[0].dims;
275+
const aRank = inputShape.length;
276+
const dimAOuter = inputShape[aRank - 2];
277+
const dimInner = attributes.k;
278+
const dimBOuter = attributes.n;
279+
const batchDims = inputShape.slice(0, aRank - 2);
280+
const batchSize = ShapeUtil.size(batchDims);
281+
const blobSize = inputs[1].dims[2];
282+
const blobSizeInWords = blobSize / 4;
283+
const dataType = inputs[0].dataType;
284+
const aComponents = getMaxComponents(attributes.k);
285+
const bComponents = getMaxComponents(blobSizeInWords);
286+
const outputShape = batchDims.concat([dimAOuter, dimBOuter]);
287+
288+
const workgroupSize = 128;
289+
const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
290+
const workgroupX = workgroupSize / workgroupY;
291+
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
292+
const aLengthPerTile = tileSize / aComponents;
293+
const blocksPerTile = tileSize / attributes.blockSize;
294+
const dispatchSize = ShapeUtil.size(outputShape) / workgroupY;
295+
296+
const programUniforms: ProgramUniform[] = [];
297+
const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
298+
const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
299+
bShape.splice(-1, 1, blobSizeInWords / bComponents);
300+
programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
301+
programUniforms.push(...createTensorShapeVariables(bShape));
302+
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
303+
if (inputs.length === 4) {
304+
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
305+
}
306+
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter];
307+
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
308+
309+
const getShaderSource = (shaderHelper: ShaderHelper) => {
310+
const inputRank = inputShapeTemp.length;
311+
const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents);
312+
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
313+
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
314+
const inputVariables = [a, b, scales];
315+
const zeroPoints =
316+
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
317+
if (zeroPoints) {
318+
inputVariables.push(zeroPoints);
319+
}
320+
const outputRank = outputShapeTemp.length;
321+
const output = outputVariable('output', inputs[0].dataType, outputRank);
322+
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
323+
const readA = () => {
324+
switch (aComponents) {
325+
case 1:
326+
return `
327+
let a_data0 = vec4<${dataType}>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);
328+
let a_data1 = vec4<${dataType}>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);`;
329+
case 2:
330+
return `
331+
let a_data0 = vec4<${dataType}>(sub_a[word_offset], sub_a[word_offset + 1]);
332+
let a_data1 = vec4<${dataType}>(sub_a[word_offset + 2], sub_a[word_offset + 3]);`;
333+
case 4:
334+
return `
335+
let a_data0 = sub_a[word_offset];
336+
let a_data1 = sub_a[word_offset + 1];`;
337+
default:
338+
throw new Error(`${aComponents}-component is not supported.`);
339+
}
340+
};
341+
342+
return `
343+
var<workgroup> sub_a: array<${a.type.value}, ${aLengthPerTile}>;
344+
var<workgroup> inter_results: array<array<${output.type.value}, ${workgroupX}>, ${workgroupY}>;
345+
${shaderHelper.declareVariables(...inputVariables, output)}
346+
${shaderHelper.mainStart([workgroupX, workgroupY, 1])}
347+
let output_indices = ${output.offsetToIndices(`workgroup_index * ${workgroupY}`)};
348+
let col = output_indices[2];
349+
let row = output_indices[1];
350+
let batch = output_indices[0];
351+
let n_blocks_per_col = uniforms.b_shape[1];
352+
let num_tiles = (n_blocks_per_col - 1) / ${blocksPerTile} + 1;
353+
354+
// Loop over shared dimension.
355+
for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
356+
let a_col_start = tile * ${aLengthPerTile};
357+
// load one tile A data into shared memory.
358+
for (var a_offset = local_idx; a_offset < ${aLengthPerTile}; a_offset += ${workgroupSize})
359+
{
360+
let a_col = a_col_start + a_offset;
361+
if (a_col < uniforms.a_shape[2])
362+
{
363+
sub_a[a_offset] = ${a.getByIndices(`${a.type.indices}(batch, row, a_col)`)};
364+
} else {
365+
sub_a[a_offset] = ${a.type.value}(0);
366+
}
367+
}
368+
workgroupBarrier();
369+
370+
// each thread process one block
371+
let b_row = col + local_id.y;
372+
let block = tile * ${blocksPerTile} + local_id.x;
373+
${
374+
zeroPoints
375+
? `
376+
let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
377+
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
378+
let zero_point_word_index = zero_point_byte_count >> 0x2u;
379+
let zero_point_byte_offset = zero_point_byte_count & 0x3u;
380+
let zero_point_nibble_offset: u32 = block & 0x1u;
381+
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
382+
let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
383+
let zero_point = ${dataType}((zero_point_word) & 0xFu);`
384+
: `
385+
// The default zero point is 8 for unsigned 4-bit quantization.
386+
let zero_point = ${dataType}(${8.0});`
387+
}
388+
let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
389+
let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
390+
var word_offset = local_id.x * ${attributes.blockSize / aComponents};
391+
for (var i: u32 = 0; i < ${bComponents}; i++) {
392+
${readA()}
393+
let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
394+
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
395+
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
396+
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
397+
{ length: 4 },
398+
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
399+
).join(', ')});
400+
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
401+
inter_results[local_id.y][local_id.x] += ${Array.from(
402+
{ length: 2 },
403+
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
404+
).join(' + ')};
405+
word_offset += ${8 / aComponents};
406+
}
407+
workgroupBarrier();
408+
}
409+
410+
if (local_idx < ${workgroupY}) {
411+
var output_value: ${output.type.value} = ${output.type.value}(0);
412+
for (var b = 0u; b < ${workgroupX}; b++) {
413+
output_value += inter_results[local_idx][b];
414+
}
415+
if (col + local_idx < uniforms.output_shape[2])
416+
{
417+
${output.setByIndices(`${output.type.indices}(batch, row, col + local_idx)`, 'output_value')}
418+
}
419+
}
420+
}`;
421+
};
422+
return {
423+
name: 'BlockwiseMatMulNBits32',
424+
shaderCache: {
425+
hint: `${attributes.blockSize};${aComponents};${bComponents};${workgroupX};${workgroupY}`,
426+
inputDependencies: Array(inputs.length).fill('rank'),
427+
},
428+
getRunData: () => ({
429+
outputs: [{ dims: outputShape, dataType }],
430+
dispatchGroup: { x: dispatchSize },
431+
programUniforms,
432+
}),
433+
getShaderSource,
434+
};
435+
};
436+
269437
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
270438
validateInputs(context.inputs, attributes);
271-
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
439+
if (
440+
attributes.blockSize === 32 &&
441+
context.adapterInfo.isVendor('intel') &&
442+
context.adapterInfo.isArchitecture('gen-12lp')
443+
) {
444+
context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
445+
} else {
446+
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
447+
}
272448
};
273449

274450
export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes =>

js/web/lib/wasm/jsep/webgpu/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ export enum GpuDataType {
1515
}
1616
export type GpuDataId = number;
1717

18-
export type GpuArchitecture = 'ampere';
18+
export type GpuArchitecture = 'ampere' | 'gen-12lp';
1919
export type GpuVendor = 'amd' | 'intel' | 'nvidia';
2020
export interface AdapterInfo {
2121
isArchitecture: (architecture: GpuArchitecture) => boolean;

0 commit comments

Comments
 (0)