2727@triton .jit
2828def dequant_8bit_kernel (
2929 a_ptr ,
30- c_ptr ,
31- quant_ptr ,
30+ out_ptr ,
31+ code_ptr ,
3232 absmax_ptr ,
33- num_paired_elements ,
33+ n ,
3434 QUANT_BLOCK : tl .constexpr ,
3535 SPLIT_SIZE : tl .constexpr ,
3636):
3737 pid = tl .program_id (axis = 0 )
3838 block_start = pid * SPLIT_SIZE
3939 offsets = block_start + tl .arange (0 , SPLIT_SIZE )
40- mask = offsets < num_paired_elements
41-
42- a = tl .load (a_ptr + offsets , mask )
43- a = a .to (tl .uint8 )
44-
45- # apply conversion
46- scaled_int8 = tl .load (quant_ptr + a , mask )
47-
48- abs_blocks_lim = (num_paired_elements // QUANT_BLOCK ) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
49- abs_offsets = offsets // QUANT_BLOCK
50- mask_blocked = offsets < abs_blocks_lim
51-
52- absmax = tl .load (absmax_ptr + abs_offsets , mask_blocked )
53- # apply scales
54- out_dq = scaled_int8 * absmax
55-
56- offs = block_start + tl .arange (0 , SPLIT_SIZE )
57- mask = offs < num_paired_elements
58- tl .store (c_ptr + offs , out_dq , mask )
40+ mask = offsets < n
41+ out_dq = dequant_8bit_blockwise_kernel_util (a_ptr , offsets , code_ptr , absmax_ptr , mask , QUANT_BLOCK )
42+ tl .store (out_ptr + offsets , out_dq , mask )
5943
6044
6145def dequant_8bit_blockwise (
@@ -66,21 +50,21 @@ def dequant_8bit_blockwise(
6650 dtype : torch .dtype = None ,
6751 out : torch .Tensor = None ,
6852):
69- number_of_paired_elements = a .numel ()
53+ n = a .numel ()
7054 if out is None :
7155 if dtype is None :
7256 raise ValueError ("If out is None, dtype must be specified" )
7357 out = torch .empty_like (a , dtype = dtype , device = a .device )
7458
7559 SPLIT_SIZE = 256
7660 # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
77- grid = (triton .cdiv (number_of_paired_elements , SPLIT_SIZE ),)
61+ grid = (triton .cdiv (n , SPLIT_SIZE ),)
7862 dequant_8bit_kernel [grid ](
7963 a ,
8064 out ,
8165 quant_state_code ,
8266 absmax ,
83- number_of_paired_elements ,
67+ n ,
8468 quant_blocksize ,
8569 SPLIT_SIZE ,
8670 )
@@ -115,39 +99,9 @@ def quantize_8bit_blockwise_kernel(
11599
116100 A = tl .load (A_ptr + offsets , mask = mask , other = 0.0 )
117101
118- # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
119- A_reshaped = tl .reshape (A , (SPLIT_NUM_BLOCKS , BLOCK_SIZE ))
120-
121- # Calculating absamax for each block
122- absmax = tl .max (tl .abs (A_reshaped ), axis = 1 )
102+ quantized , absmax = quantize_8bit_blockwise_kernel_util (A , code_ptr , CODE_SIZE , BLOCK_SIZE , SPLIT_NUM_BLOCKS )
123103 tl .store (absmax_ptr + block_start_idx + tl .arange (0 , SPLIT_NUM_BLOCKS ), absmax )
124-
125- A_normalized = A_reshaped / absmax [:, None ]
126- A_normalized = tl .clamp (A_normalized , - 1.0 , 1.0 )
127-
128- lower_pivot = tl .zeros ((SPLIT_NUM_BLOCKS , BLOCK_SIZE ), dtype = tl .int32 )
129- upper_pivot = tl .full ((SPLIT_NUM_BLOCKS , BLOCK_SIZE ), CODE_SIZE - 1 , dtype = tl .int32 )
130-
131- for _ in range (8 ): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
132- pivot = (lower_pivot + upper_pivot ) // 2
133- val = tl .load (code_ptr + pivot )
134- is_higher = A_normalized > val # code[pivot]
135- lower_pivot = tl .where (is_higher , pivot , lower_pivot )
136- upper_pivot = tl .where (is_higher , upper_pivot , pivot )
137-
138- # Choose closest level
139- lower_val = tl .load (code_ptr + lower_pivot )
140- upper_val = tl .load (code_ptr + upper_pivot )
141- lower_dist = tl .abs (A_normalized - lower_val )
142- upper_dist = tl .abs (A_normalized - upper_val )
143- quantized = tl .where (lower_dist <= upper_dist , lower_pivot , upper_pivot ).to (tl .uint8 )
144-
145- # too slow approach
146- # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
147- # quantized = tl.argmin(diff, axis=2).to(tl.uint8)
148-
149- quantized_flat = tl .reshape (quantized , (BLOCK_SIZE * SPLIT_NUM_BLOCKS ,))
150- tl .store (out_ptr + offsets , quantized_flat , mask = mask )
104+ tl .store (out_ptr + offsets , quantized , mask = mask )
151105
152106
153107def quantize_blockwise_triton (A , code , blocksize , absmax = None , out = None ):
@@ -180,17 +134,17 @@ def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
180134
181135
182136@triton .jit
183- def quantize_8bit_blockwise_core (
137+ def quantize_8bit_blockwise_kernel_util (
184138 a ,
185- qmap_ptr ,
139+ code_ptr ,
186140 CODE_SIZE : tl .constexpr ,
187141 BLOCK_SIZE : tl .constexpr ,
188142 N_PER_TH : tl .constexpr ,
189143):
190144 # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
191145 a_reshaped = tl .reshape (a , (N_PER_TH , BLOCK_SIZE ))
192146
193- # Calculating absamax for each block
147+ # Calculating absmax for each block
194148 absmax = tl .max (tl .abs (a_reshaped ), axis = 1 )
195149
196150 a_normalized = a_reshaped / absmax [:, None ]
@@ -202,37 +156,40 @@ def quantize_8bit_blockwise_core(
202156 # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
203157 for _ in range (8 ):
204158 pivot = (lower_pivot + upper_pivot ) // 2
205- val = tl .load (qmap_ptr + pivot )
159+ val = tl .load (code_ptr + pivot )
206160 is_higher = a_normalized > val # code[pivot]
207161 lower_pivot = tl .where (is_higher , pivot , lower_pivot )
208162 upper_pivot = tl .where (is_higher , upper_pivot , pivot )
209163
210164 # Choose closest level
211- lower_val = tl .load (qmap_ptr + lower_pivot )
212- upper_val = tl .load (qmap_ptr + upper_pivot )
165+ lower_val = tl .load (code_ptr + lower_pivot )
166+ upper_val = tl .load (code_ptr + upper_pivot )
213167 lower_dist = tl .abs (a_normalized - lower_val )
214168 upper_dist = tl .abs (a_normalized - upper_val )
215169 quantized = tl .where (lower_dist <= upper_dist , lower_pivot , upper_pivot ).to (tl .uint8 )
216170
171+ # too slow approach
172+ # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
173+ # quantized = tl.argmin(diff, axis=2).to(tl.uint8)
174+
217175 quantized_flat = tl .reshape (quantized , (BLOCK_SIZE * N_PER_TH ,))
218176 return quantized_flat , absmax
219177
220178
221179@triton .jit
222- def dequant_8bit_kernel_util (
223- codes_ptr ,
180+ def dequant_8bit_blockwise_kernel_util (
181+ a_ptr ,
224182 offsets ,
225- qmap_ptr ,
183+ code_ptr ,
226184 absmax_ptr ,
227185 mask ,
228186 BLOCK_SIZE : tl .constexpr ,
229187):
230- codes = tl .load (codes_ptr + offsets , mask , other = 0 ).to (tl .uint8 )
231- abs_offsets = offsets // BLOCK_SIZE
232- absmax = tl .load (absmax_ptr + abs_offsets , mask = mask , other = 0.0 , eviction_policy = "evict_last" )
233-
234- # apply conversion
235- scaled_int8 = tl .load (qmap_ptr + codes , mask )
236- # apply scales
188+ a = tl .load (a_ptr + offsets , mask , other = 0 ).to (tl .uint8 )
189+ scaled_int8 = tl .load (code_ptr + a , mask )
190+ # Load scales
191+ absmax_offsets = offsets // BLOCK_SIZE
192+ absmax = tl .load (absmax_ptr + absmax_offsets , mask = mask , other = 0.0 , eviction_policy = "evict_last" )
193+ # Apply scales
237194 out_dq = scaled_int8 * absmax
238195 return out_dq
0 commit comments