@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
77import { AttributeWithCacheKey , createAttributeWithCacheKey } from '../attribute-with-cache-key' ;
88import { ComputeContext , ProgramInfo , ProgramUniform } from '../types' ;
99
10- import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
10+ import { createTensorShapeVariables , getMaxComponents , inputVariable , outputVariable , ShaderHelper , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
1111
1212// TODO support quantization bits not equal to 4
1313export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
@@ -51,124 +51,190 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
5151
5252export const createMatMulNBitsProgramInfo =
5353 ( inputs : readonly TensorView [ ] , attributes : MatMulNBitsAttributes ) : ProgramInfo => {
54- const a = inputs [ 0 ] ;
55- const b = inputs [ 1 ] ;
56- const scales = inputs [ 2 ] ;
57- const aRank = a . dims . length ;
58- const outputShape = a . dims . slice ( 0 , aRank - 1 ) . concat ( attributes . n ) ;
59- const outputSize = ShapeUtil . size ( outputShape ) ;
60-
61-
54+ const inputShape = inputs [ 0 ] . dims ;
55+ const aRank = inputShape . length ;
56+ const outputShape = inputShape . slice ( 0 , aRank - 1 ) . concat ( attributes . n ) ;
57+ const m = inputShape [ aRank - 2 ] ;
58+ const blobSize = attributes . blockSize / 8 * attributes . bits ;
59+ const blobSizeInWords = blobSize / 4 ;
60+ const outputNumber = getMaxComponents ( m ) ;
61+ const components = getMaxComponents ( attributes . n ) ;
62+ const aComponents = getMaxComponents ( attributes . k ) ;
63+ const bComponents = getMaxComponents ( blobSizeInWords ) ;
64+ const outputSize = ShapeUtil . size ( outputShape ) / components / outputNumber ;
6265 const programUniforms : ProgramUniform [ ] = [
6366 { type : DataType . uint32 , data : outputSize } , { type : DataType . uint32 , data : attributes . k } ,
6467 { type : DataType . uint32 , data : attributes . n } , { type : DataType . uint32 , data : attributes . accuracyLevel } ,
6568 { type : DataType . uint32 , data : attributes . bits } , { type : DataType . uint32 , data : attributes . blockSize }
6669 ] ;
67- programUniforms . push ( ...createTensorShapeVariables ( a . dims ) ) ;
68- programUniforms . push ( ...createTensorShapeVariables ( ShapeUtil . convertShape ( b . dims ) ) ) ;
69- programUniforms . push ( ...createTensorShapeVariables ( scales . dims ) ) ;
70+ const aShape = inputShape . slice ( ) ;
71+ aShape . splice ( - 1 , 1 , attributes . k / aComponents ) ;
72+ const bShape = ShapeUtil . convertShape ( inputs [ 1 ] . dims ) . slice ( ) ;
73+ bShape . splice ( - 1 , 1 , blobSizeInWords / bComponents ) ;
74+ programUniforms . push ( ...createTensorShapeVariables ( aShape ) ) ;
75+ programUniforms . push ( ...createTensorShapeVariables ( bShape ) ) ;
76+ programUniforms . push ( ...createTensorShapeVariables ( inputs [ 2 ] . dims ) ) ;
7077 if ( inputs . length === 4 ) {
7178 programUniforms . push ( ...createTensorShapeVariables ( ShapeUtil . convertShape ( inputs [ 3 ] . dims ) ) ) ;
7279 }
73- programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
80+ const oShape = outputShape . slice ( ) ;
81+ oShape . splice ( - 1 , 1 , attributes . n / components ) ;
82+ programUniforms . push ( ...createTensorShapeVariables ( oShape ) ) ;
7483 const getShaderSource = ( shaderHelper : ShaderHelper ) => {
75- const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length ) ;
76- const b = inputVariable ( 'b' , DataType . uint32 , inputs [ 1 ] . dims . length ) ;
84+ const a = inputVariable ( 'a' , inputs [ 0 ] . dataType , aShape . length , aComponents ) ;
85+ const b = inputVariable ( 'b' , DataType . uint32 , bShape . length , bComponents ) ;
7786 const scales = inputVariable ( 'scales' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length ) ;
7887 const inputVariables = [ a , b , scales ] ;
7988 const zeroPoints =
8089 inputs . length === 4 ? inputVariable ( 'zero_points' , DataType . uint32 , inputs [ 3 ] . dims . length ) : undefined ;
8190 if ( zeroPoints ) {
8291 inputVariables . push ( zeroPoints ) ;
8392 }
84- const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape . length ) ;
93+ const output = outputVariable ( 'output' , inputs [ 0 ] . dataType , outputShape . length , components ) ;
8594 const uniforms : UniformsArrayType = [
86- { name : 'output_size' , type : 'u32' } , { name : 'k ' , type : 'u32' } , { name : 'n ' , type : 'u32' } ,
95+ { name : 'output_size' , type : 'u32' } , { name : 'K ' , type : 'u32' } , { name : 'N ' , type : 'u32' } ,
8796 { name : 'accuracy_level' , type : 'u32' } , { name : 'bits' , type : 'u32' } , { name : 'block_size' , type : 'u32' }
8897 ] ;
8998 const nBlocksPerCol = Math . floor ( ( attributes . k + attributes . blockSize - 1 ) / attributes . blockSize ) ;
90- const blobSize = attributes . blockSize / 8 * attributes . bits ;
91- const wordPerBlob = blobSize / 4 ;
9299 const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
93- return `
94- fn ortUnpack8x4snorm(value: u32) -> array<${ dataType } , 8>{
95- var result = array<${ dataType } , 8>();
100+
101+ const qDqDataType = ( ( ) => {
102+ switch ( aComponents ) {
103+ case 1 :
104+ return `array<${ dataType } , 8>` ;
105+ case 2 :
106+ return `mat4x2<${ dataType } >` ;
107+ case 4 :
108+ return `mat2x4<${ dataType } >` ;
109+ default :
110+ throw new Error ( `${ aComponents } -component is not supported.` ) ;
111+ }
112+ } ) ( ) ;
113+
114+ const dequantizeImpl = `
115+ fn dequantize(quantized: ${ qDqDataType } , zero_point: ${ dataType } , scale: ${ dataType } ) -> ${ qDqDataType } {
116+ ${ ( ( ) => {
117+ if ( aComponents === 1 ) {
118+ return `var dequantized = ${ qDqDataType } (${
119+ Array . from ( { length : 8 } , ( _ , i ) => `(quantized[${ i } ] - zero_point) * scale` ) . join ( ', ' ) } );
120+ return dequantized;` ;
121+ } else {
122+ return `var zero_points: ${ qDqDataType } = ${ qDqDataType } (${ Array ( 8 ) . fill ( 'zero_point' ) . join ( ',' ) } );
123+ return (quantized - zero_points) * scale;` ;
124+ }
125+ } ) ( ) }
126+ }` ;
127+ const ortUnpack8x4snormImpl = `
128+ fn ortUnpack8x4snorm(value: u32) -> ${ qDqDataType } {
129+ var quantized: ${ qDqDataType } ;
96130 var offset: u32 = 0;
97131 let count: u32 = 4;
98132 for (var i: u32 = 0; i < 8u; i++) {
99- result[i] = ${ dataType } (extractBits(value, offset, count));
133+ var result = ${ dataType } (extractBits(value, offset, count));
134+ ${ ( ( ) => {
135+ switch ( aComponents ) {
136+ case 1 :
137+ return 'quantized[i] = result;' ;
138+ case 2 :
139+ return 'quantized[i / 2][i % 2] = result;' ;
140+ case 4 :
141+ return 'quantized[i / 4][i % 4] = result;' ;
142+ default :
143+ throw new Error ( `${ aComponents } -component is not supported.` ) ;
144+ }
145+ } ) ( ) }
100146 offset += count;
101147 }
102- return result;
103- }
148+ return quantized;
149+ }` ;
150+
151+ const updateZeroPointIndex = zeroPoints ? `
152+ zero_point_offset += 4;
153+ if (zero_point_offset == 32) {
154+ zero_point_offset = 0;
155+ zero_point_index++;
156+ zero_point_word = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
157+ }` :
158+ '' ;
159+
160+ return `
161+ ${ dequantizeImpl } ;
162+ ${ ortUnpack8x4snormImpl } ;
104163 ${ shaderHelper . registerUniforms ( uniforms ) . declareVariables ( ...inputVariables , output ) }
105164 ${ shaderHelper . mainStart ( ) }
106165 ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size' ) }
107- var value: ${ dataType } = 0.0;
108- let output_indices = ${ output . offsetToIndices ( 'global_idx' ) } ;
109- var a_indices: ${ a . type . indices } = output_indices;
166+ var output_values: array<${ output . type . value } , ${ outputNumber } >;
167+ var output_indices = ${ output . offsetToIndices ( 'global_idx' ) } ;
110168 var n = ${ output . indicesGet ( 'output_indices' , aRank - 1 ) } ;
169+ var m = ${ output . indicesGet ( 'output_indices' , aRank - 2 ) } ;
170+ var a_indices: ${ a . type . indices } = output_indices;
111171 // Two zero points are packed into one byte because uniforms.bits <= 4.
112172 // zero_point_offset is either 0 or 4. It is bit offset within one byte.
113173 // TODO support zero_point_offset for bits > 4
114174 ${
115175 zeroPoints ? `
116- var zero_point_index: u32 = n * ((${ nBlocksPerCol } + 1) / 2) / 4;
117- var zero_point_word: u32 = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
118- var zero_point_offset: u32 = 0;` :
176+ var zero_point_index: u32 = n * ${ components } * ((${ nBlocksPerCol } + 1) / 2) / 4;
177+ var zero_point_word: u32 = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
178+ var zero_point_offset: u32 = 0;` :
119179 '' }
120- var scale_idex = n * ${ nBlocksPerCol } ;
180+ var scale_index = n * ${ nBlocksPerCol * components } ;
121181 var b_indices: ${ b . type . indices } ;
122- ${ b . indicesSet ( 'b_indices' , '0' , 'n' ) } ;
123- var block_offset: u32 = 0;
124- for (var block: u32 = 0; block < ${ nBlocksPerCol } ; block++) {
125- // The scale and zero points are computed per block.
126- let scale = ${ scales . getByOffset ( 'scale_idex' ) } ;
127- // The default zero point is 8 for unsigned 4-bit quantization.
128- let zero_point: ${ dataType } = ${
129- zeroPoints ? `${ dataType } (extractBits(zero_point_word, zero_point_offset, 4))` : 8.0 } ;
130- ${ b . indicesSet ( 'b_indices' , '1' , 'block' ) } ;
131- var word_offset: u32 = block_offset;
132- for (var word: u32 = 0; word < ${ wordPerBlob } ; word++) {
133- ${ b . indicesSet ( 'b_indices' , '2' , 'word' ) } ;
134- let b_value = ${ b . getByIndices ( 'b_indices' ) } ;
135- let b_quantized_values: array<${ dataType } , 8> = ortUnpack8x4snorm(b_value);
136- // Number of B elements per 32-bit word is 32/bits = 32/4 = 8
137- var offset: u32 = word_offset;
138- for (var i: u32 = 0; i < 8; i++) {
139- ${ a . indicesSet ( 'a_indices' , aRank - 1 , 'offset' ) } ;
140- let a_value = ${ a . getByIndices ( 'a_indices' ) } ;
141- let b_quantized_value = b_quantized_values[i];
142- let b_dequantized_value = (b_quantized_value - zero_point) * scale;
143- value += a_value * b_dequantized_value;
144- offset++;
182+ for (var c: u32 = 0; c < ${ components } ; c++) {
183+ ${ b . indicesSet ( 'b_indices' , '0' , `n * ${ components } + c` ) } ;
184+ var block_offset: u32 = 0;
185+ for (var block: u32 = 0; block < ${ nBlocksPerCol } ; block++) {
186+ // The scale and zero points are computed per block.
187+ let scale = ${ scales . getByOffset ( 'scale_index' ) } ;
188+ // The default zero point is 8 for unsigned 4-bit quantization.
189+ let zero_point = ${ dataType } (${ zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0 } );
190+ ${ b . indicesSet ( 'b_indices' , '1' , 'block' ) } ;
191+ var word_offset: u32 = block_offset;
192+ for (var word: u32 = 0; word < ${ blobSizeInWords } ; word += ${ bComponents } ) {
193+ ${ b . indicesSet ( 'b_indices' , '2' , 'word' ) } ;
194+ let b_data = ${ b . getByIndices ( 'b_indices' ) } ;
195+ for (var i: u32 = 0; i < ${ bComponents } ; i++) {
196+ let b_value = ${ bComponents === 1 ? 'b_data' : 'b_data[word + i]' } ;
197+ let b_quantized_values: ${ qDqDataType } = ortUnpack8x4snorm(b_value);
198+ let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale);
199+ // Number of B elements per 32-bit word is 32/bits = 32/4 = 8
200+ var offset: u32 = word_offset;
201+ for (var j: u32 = 0; j < 8/${ aComponents } ; j++) {
202+ ${ a . indicesSet ( 'a_indices' , aRank - 1 , `offset/${ aComponents } ` ) } ;
203+ for (var k: u32 = 0; k < ${ outputNumber } u; k++) {
204+ ${ a . indicesSet ( 'a_indices' , aRank - 2 , `m * ${ outputNumber } + k` ) } ;
205+ let a_data = ${ a . getByIndices ( 'a_indices' ) } ;
206+ output_values[k]${ components > 1 ? '[c]' : '' } += ${
207+ aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])' } ;
208+ }
209+ offset += ${ aComponents } ;
210+ }
211+ word_offset += 8;
212+ }
145213 }
146- word_offset += 8;
214+ scale_index++;
215+ ${ updateZeroPointIndex }
216+ block_offset += uniforms.block_size;
147217 }
148- scale_idex++;
218+ // Drop the trailing 4 bits if the zero_point_offset is not a byte boundary to align with the next byte.
149219 ${
150- zeroPoints ? `
151- if (zero_point_offset == 28) {
152- zero_point_offset = 0;
153- zero_point_index++;
154- zero_point_word = ${ zeroPoints . getByOffset ( 'zero_point_index' ) } ;
155- } else {
156- zero_point_offset += 4;
157- }` :
220+ zeroPoints ? `if (zero_point_offset % 8 > 0) {
221+ ${ updateZeroPointIndex }
222+ }` :
158223 '' }
159- block_offset += uniforms.block_size;
160- }
161- ${ output . setByOffset ( 'global_idx' , 'value' ) } ;
162- }
163- ` ;
224+ }
225+ for (var k: u32 = 0u; k < ${ outputNumber } u; k++) {
226+ ${ output . indicesSet ( 'output_indices' , aRank - 2 , `${ outputNumber + ' * m + k' } ` ) } ;
227+ ${ output . setByIndices ( 'output_indices' , 'output_values[k]' ) }
228+ }
229+ }` ;
164230 } ;
165231 return {
166232 name : 'MatMulNBits' ,
167233 shaderCache :
168234 { hint : `${ attributes . cacheKey } ;${ inputs . length } ` , inputDependencies : Array ( inputs . length ) . fill ( 'rank' ) } ,
169235 getRunData : ( ) => ( {
170236 outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
171- dispatchGroup : { x : Math . ceil ( outputSize / 64 ) } ,
237+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
172238 programUniforms
173239 } ) ,
174240 getShaderSource
0 commit comments