@@ -266,9 +266,185 @@ export const createMatMulNBitsProgramInfo = (
266
266
} ;
267
267
} ;
268
268
269
+ // Currently, only support blockSize = 32.
270
+ export const createMatMulNBitsBlockSize32ProgramInfo = (
271
+ inputs : readonly TensorView [ ] ,
272
+ attributes : MatMulNBitsAttributes ,
273
+ ) : ProgramInfo => {
274
+ const inputShape = inputs [ 0 ] . dims ;
275
+ const aRank = inputShape . length ;
276
+ const dimAOuter = inputShape [ aRank - 2 ] ;
277
+ const dimInner = attributes . k ;
278
+ const dimBOuter = attributes . n ;
279
+ const batchDims = inputShape . slice ( 0 , aRank - 2 ) ;
280
+ const batchSize = ShapeUtil . size ( batchDims ) ;
281
+ const blobSize = inputs [ 1 ] . dims [ 2 ] ;
282
+ const blobSizeInWords = blobSize / 4 ;
283
+ const dataType = inputs [ 0 ] . dataType ;
284
+ const aComponents = getMaxComponents ( attributes . k ) ;
285
+ const bComponents = getMaxComponents ( blobSizeInWords ) ;
286
+ const outputShape = batchDims . concat ( [ dimAOuter , dimBOuter ] ) ;
287
+
288
+ const workgroupSize = 128 ;
289
+ const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1 ;
290
+ const workgroupX = workgroupSize / workgroupY ;
291
+ const tileSize = workgroupX * bComponents * 8 ; // each uint32 has 8 data.
292
+ const aLengthPerTile = tileSize / aComponents ;
293
+ const blocksPerTile = tileSize / attributes . blockSize ;
294
+ const dispatchSize = ShapeUtil . size ( outputShape ) / workgroupY ;
295
+
296
+ const programUniforms : ProgramUniform [ ] = [ ] ;
297
+ const inputShapeTemp = [ batchSize , dimAOuter , dimInner / aComponents ] ;
298
+ const bShape = ShapeUtil . convertShape ( inputs [ 1 ] . dims ) . slice ( ) ;
299
+ bShape . splice ( - 1 , 1 , blobSizeInWords / bComponents ) ;
300
+ programUniforms . push ( ...createTensorShapeVariables ( inputShapeTemp ) ) ;
301
+ programUniforms . push ( ...createTensorShapeVariables ( bShape ) ) ;
302
+ programUniforms . push ( ...createTensorShapeVariables ( inputs [ 2 ] . dims ) ) ;
303
+ if ( inputs . length === 4 ) {
304
+ programUniforms . push ( ...createTensorShapeVariables ( ShapeUtil . convertShape ( inputs [ 3 ] . dims ) ) ) ;
305
+ }
306
+ const outputShapeTemp = [ batchSize , dimAOuter , dimBOuter ] ;
307
+ programUniforms . push ( ...createTensorShapeVariables ( outputShapeTemp ) ) ;
308
+
309
+ const getShaderSource = ( shaderHelper : ShaderHelper ) => {
310
+ const inputRank = inputShapeTemp . length ;
311
+ const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , inputRank , aComponents ) ;
312
+ const b = inputVariable ( 'b' , DataType . uint32 , bShape . length , bComponents ) ;
313
+ const scales = inputVariable ( 'scales' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length ) ;
314
+ const inputVariables = [ a , b , scales ] ;
315
+ const zeroPoints =
316
+ inputs . length === 4 ? inputVariable ( 'zero_points' , DataType . uint32 , inputs [ 3 ] . dims . length ) : undefined ;
317
+ if ( zeroPoints ) {
318
+ inputVariables . push ( zeroPoints ) ;
319
+ }
320
+ const outputRank = outputShapeTemp . length ;
321
+ const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputRank ) ;
322
+ const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
323
+ const readA = ( ) => {
324
+ switch ( aComponents ) {
325
+ case 1 :
326
+ return `
327
+ let a_data0 = vec4<${ dataType } >(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);
328
+ let a_data1 = vec4<${ dataType } >(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);` ;
329
+ case 2 :
330
+ return `
331
+ let a_data0 = vec4<${ dataType } >(sub_a[word_offset], sub_a[word_offset + 1]);
332
+ let a_data1 = vec4<${ dataType } >(sub_a[word_offset + 2], sub_a[word_offset + 3]);` ;
333
+ case 4 :
334
+ return `
335
+ let a_data0 = sub_a[word_offset];
336
+ let a_data1 = sub_a[word_offset + 1];` ;
337
+ default :
338
+ throw new Error ( `${ aComponents } -component is not supported.` ) ;
339
+ }
340
+ } ;
341
+
342
+ return `
343
+ var<workgroup> sub_a: array<${ a . type . value } , ${ aLengthPerTile } >;
344
+ var<workgroup> inter_results: array<array<${ output . type . value } , ${ workgroupX } >, ${ workgroupY } >;
345
+ ${ shaderHelper . declareVariables ( ...inputVariables , output ) }
346
+ ${ shaderHelper . mainStart ( [ workgroupX , workgroupY , 1 ] ) }
347
+ let output_indices = ${ output . offsetToIndices ( `workgroup_index * ${ workgroupY } ` ) } ;
348
+ let col = output_indices[2];
349
+ let row = output_indices[1];
350
+ let batch = output_indices[0];
351
+ let n_blocks_per_col = uniforms.b_shape[1];
352
+ let num_tiles = (n_blocks_per_col - 1) / ${ blocksPerTile } + 1;
353
+
354
+ // Loop over shared dimension.
355
+ for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
356
+ let a_col_start = tile * ${ aLengthPerTile } ;
357
+ // load one tile A data into shared memory.
358
+ for (var a_offset = local_idx; a_offset < ${ aLengthPerTile } ; a_offset += ${ workgroupSize } )
359
+ {
360
+ let a_col = a_col_start + a_offset;
361
+ if (a_col < uniforms.a_shape[2])
362
+ {
363
+ sub_a[a_offset] = ${ a . getByIndices ( `${ a . type . indices } (batch, row, a_col)` ) } ;
364
+ } else {
365
+ sub_a[a_offset] = ${ a . type . value } (0);
366
+ }
367
+ }
368
+ workgroupBarrier();
369
+
370
+ // each thread process one block
371
+ let b_row = col + local_id.y;
372
+ let block = tile * ${ blocksPerTile } + local_id.x;
373
+ ${
374
+ zeroPoints
375
+ ? `
376
+ let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
377
+ let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
378
+ let zero_point_word_index = zero_point_byte_count >> 0x2u;
379
+ let zero_point_byte_offset = zero_point_byte_count & 0x3u;
380
+ let zero_point_nibble_offset: u32 = block & 0x1u;
381
+ let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
382
+ let zero_point_word = ${ zeroPoints . getByOffset ( 'zero_point_word_index' ) } >> zero_point_bits_offset;
383
+ let zero_point = ${ dataType } ((zero_point_word) & 0xFu);`
384
+ : `
385
+ // The default zero point is 8 for unsigned 4-bit quantization.
386
+ let zero_point = ${ dataType } (${ 8.0 } );`
387
+ }
388
+ let scale = ${ scales . getByOffset ( `b_row * n_blocks_per_col + block` ) } ;
389
+ let b_data = ${ b . getByIndices ( `${ b . type . indices } (b_row, block, 0)` ) } ;
390
+ var word_offset = local_id.x * ${ attributes . blockSize / aComponents } ;
391
+ for (var i: u32 = 0; i < ${ bComponents } ; i++) {
392
+ ${ readA ( ) }
393
+ let b_value = ${ bComponents === 1 ? `b_data` : `b_data[i]` } ;
394
+ let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
395
+ let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
396
+ let b_quantized_values = mat2x4<${ dataType } >(${ Array . from (
397
+ { length : 4 } ,
398
+ ( _ , i ) => `${ dataType } (b_value_lower[${ i } ]), ${ dataType } (b_value_upper[${ i } ])` ,
399
+ ) . join ( ', ' ) } );
400
+ let b_dequantized_values = (b_quantized_values - mat2x4<${ dataType } >(${ Array ( 8 ) . fill ( 'zero_point' ) . join ( ',' ) } )) * scale;
401
+ inter_results[local_id.y][local_id.x] += ${ Array . from (
402
+ { length : 2 } ,
403
+ ( _ , i ) => `${ `dot(a_data${ i } , b_dequantized_values[${ i } ])` } ` ,
404
+ ) . join ( ' + ' ) } ;
405
+ word_offset += ${ 8 / aComponents } ;
406
+ }
407
+ workgroupBarrier();
408
+ }
409
+
410
+ if (local_idx < ${ workgroupY } ) {
411
+ var output_value: ${ output . type . value } = ${ output . type . value } (0);
412
+ for (var b = 0u; b < ${ workgroupX } ; b++) {
413
+ output_value += inter_results[local_idx][b];
414
+ }
415
+ if (col + local_idx < uniforms.output_shape[2])
416
+ {
417
+ ${ output . setByIndices ( `${ output . type . indices } (batch, row, col + local_idx)` , 'output_value' ) }
418
+ }
419
+ }
420
+ }` ;
421
+ } ;
422
+ return {
423
+ name : 'BlockwiseMatMulNBits32' ,
424
+ shaderCache : {
425
+ hint : `${ attributes . blockSize } ;${ aComponents } ;${ bComponents } ;${ workgroupX } ;${ workgroupY } ` ,
426
+ inputDependencies : Array ( inputs . length ) . fill ( 'rank' ) ,
427
+ } ,
428
+ getRunData : ( ) => ( {
429
+ outputs : [ { dims : outputShape , dataType } ] ,
430
+ dispatchGroup : { x : dispatchSize } ,
431
+ programUniforms,
432
+ } ) ,
433
+ getShaderSource,
434
+ } ;
435
+ } ;
436
+
269
437
export const matMulNBits = ( context : ComputeContext , attributes : MatMulNBitsAttributes ) : void => {
270
438
validateInputs ( context . inputs , attributes ) ;
271
- context . compute ( createMatMulNBitsProgramInfo ( context . inputs , attributes ) ) ;
439
+ if (
440
+ attributes . blockSize === 32 &&
441
+ context . adapterInfo . isVendor ( 'intel' ) &&
442
+ context . adapterInfo . isArchitecture ( 'gen-12lp' )
443
+ ) {
444
+ context . compute ( createMatMulNBitsBlockSize32ProgramInfo ( context . inputs , attributes ) ) ;
445
+ } else {
446
+ context . compute ( createMatMulNBitsProgramInfo ( context . inputs , attributes ) ) ;
447
+ }
272
448
} ;
273
449
274
450
export const parseMatMulNBitsAttributes = ( attributes : Record < string , unknown > ) : MatMulNBitsAttributes =>
0 commit comments