77from .utils import _FP4_QUANT_TABLE , _NF4_QUANT_TABLE
88
99
10- # @triton.autotune(
11- # configs=[
12- # triton.Config({'SPLIT_SIZE': 64}),
13- # triton.Config({'SPLIT_SIZE': 128}),
14- # triton.Config({'SPLIT_SIZE': 256}),
15- # triton.Config({'SPLIT_SIZE': 512}),
16- # triton.Config({'SPLIT_SIZE': 1024}),
17- # triton.Config({'SPLIT_SIZE': 2048}),
18- # triton.Config({'SPLIT_SIZE': 4096}),
19- # triton.Config({'SPLIT_SIZE': 8192}),
20- # triton.Config({'SPLIT_SIZE': 16384}),
21- # ],
22- # key=['SPLIT_SIZE'],
23- # )
2410@triton .jit
2511def dequant_8bit_kernel (
2612 a_ptr ,
2713 c_ptr ,
2814 quant_ptr ,
2915 absmax_ptr ,
30- # bias_ptr,
3116 num_paired_elements ,
3217 QUANT_BLOCK : tl .constexpr ,
3318 SPLIT_SIZE : tl .constexpr ,
3419):
35- pid = tl .program_id (axis = 0 ) # We use a 1D launch grid so axis is 0.
20+ pid = tl .program_id (axis = 0 )
3621 block_start = pid * SPLIT_SIZE
3722 offsets = block_start + tl .arange (0 , SPLIT_SIZE )
3823 mask = offsets < num_paired_elements
3924
4025 a = tl .load (a_ptr + offsets , mask )
4126 a = a .to (tl .uint8 , bitcast = True )
4227
43- # bias = tl.load(bias_ptr)
44-
4528 # apply conversion
4629 scaled_int8 = tl .load (quant_ptr + a , mask )
4730
@@ -52,7 +35,6 @@ def dequant_8bit_kernel(
5235 absmax = tl .load (absmax_ptr + abs_offsets , mask_blocked )
5336 # apply scales
5437 out_dq = scaled_int8 * absmax
55- # out_dq = out_dq + bias
5638
5739 offs = block_start + tl .arange (0 , SPLIT_SIZE )
5840 mask = offs < num_paired_elements
@@ -79,19 +61,7 @@ def dequant_int8_blockwise(
7961
8062@triton .autotune (
8163 configs = [
82- # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
83- # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
84- # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
85- #
8664 triton .Config ({"SPLIT_NUM_BLOCKS" : 1 , "grf_mode" : "auto" }, num_stages = 4 , num_warps = 32 ),
87- #
88- # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32),
89- # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
90- # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32),
91- # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
92- # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32),
93- # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32),
94- # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
9565 ],
9666 key = ["BLOCK_SIZE" ],
9767)
@@ -124,9 +94,6 @@ def quantize_blockwise_kernel(
12494 A_normalized = A_reshaped / absmax [:, None ]
12595 A_normalized = tl .clamp (A_normalized , - 1.0 , 1.0 )
12696
127- # This can be fruitful, but compiler should preload it
128- # code = tl.load(code_ptr + tl.arange(0, CODE_SIZE))
129-
13097 lower_pivot = tl .zeros ((SPLIT_NUM_BLOCKS , BLOCK_SIZE ), dtype = tl .int32 )
13198 upper_pivot = tl .full ((SPLIT_NUM_BLOCKS , BLOCK_SIZE ), CODE_SIZE - 1 , dtype = tl .int32 )
13299
@@ -176,24 +143,6 @@ def unite_2_int4(x, y):
176143 return (x & 0xF ) | (y << 4 )
177144
178145
179- # @triton.autotune(
180- # configs=[
181- # # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
182- # # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
183- # # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
184- # #
185- # triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
186- # #
187- # # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32),
188- # # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
189- # # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32),
190- # # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
191- # # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32),
192- # # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32),
193- # # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
194- # ],
195- # key=["BLOCK_SIZE"],
196- # )
197146@triton .jit
198147def quantize_4bit_blockwise_kernel (
199148 A_ptr ,
@@ -261,11 +210,6 @@ def quantize_4bit_blockwise_triton(A, blocksize, code, blocks, absmax, quantized
261210
262211 split_num_blocks = 1
263212 grid = (triton .cdiv (blocks , split_num_blocks ),)
264- # grid = (1, )
265- # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
266- # print(" blocksize, split_num_blocks: ", blocksize, split_num_blocks)
267- # print(" blocksize, split_num_blocks: ", blocksize, split_num_blocks*2)
268- # print("A shape: ", A.shape, " numel: ", n, " blocks: ", blocks)
269213 quantize_4bit_blockwise_kernel [grid ](
270214 A_ptr = A ,
271215 code_ptr = code ,
@@ -280,20 +224,6 @@ def quantize_4bit_blockwise_triton(A, blocksize, code, blocks, absmax, quantized
280224 return quantized_out , absmax
281225
282226
283- # @triton.autotune(
284- # configs=[
285- # # triton.Config({'SPLIT_SIZE': 64}),
286- # # triton.Config({'SPLIT_SIZE': 128}),
287- # # triton.Config({'SPLIT_SIZE': 256}),
288- # triton.Config({'SPLIT_SIZE': 512}),
289- # # triton.Config({'SPLIT_SIZE': 1024}),
290- # # triton.Config({'SPLIT_SIZE': 2048}),
291- # # triton.Config({'SPLIT_SIZE': 4096}),
292- # # triton.Config({'SPLIT_SIZE': 8192}),
293- # # triton.Config({'SPLIT_SIZE': 16384}),
294- # ],
295- # key=['SPLIT_SIZE'],
296- # )
297227@triton .jit
298228def dequant_4bit_kernel (
299229 a_ptr , c_ptr , quant_ptr , absmax_ptr , num_paired_elements , QUANT_BLOCK : tl .constexpr , SPLIT_SIZE : tl .constexpr
0 commit comments