|
| 1 | +import torch |
| 2 | + |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | + |
| 6 | + |
| 7 | +# @triton.autotune( |
| 8 | +# configs=[ |
| 9 | +# # triton.Config({'SPLIT_SIZE': 64}), |
| 10 | +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), |
| 11 | +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), |
| 12 | +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), |
| 13 | +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), |
| 14 | +# # triton.Config({'SPLIT_SIZE': 128}), |
| 15 | +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), |
| 16 | +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), |
| 17 | +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), |
| 18 | +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), |
| 19 | +# triton.Config({"SPLIT_SIZE": 256}), |
| 20 | +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), |
| 21 | +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), |
| 22 | +# triton.Config({"SPLIT_SIZE": 512}), |
| 23 | +# # triton.Config({'SPLIT_SIZE': 1024}), |
| 24 | +# ], |
| 25 | +# key=["num_paired_elements", "QUANT_BLOCK"], |
| 26 | +# ) |
| 27 | +@triton.jit |
| 28 | +def dequant_8bit_kernel( |
| 29 | + a_ptr, |
| 30 | + out_ptr, |
| 31 | + code_ptr, |
| 32 | + absmax_ptr, |
| 33 | + n, |
| 34 | + QUANT_BLOCK: tl.constexpr, |
| 35 | + SPLIT_SIZE: tl.constexpr, |
| 36 | +): |
| 37 | + pid = tl.program_id(axis=0) |
| 38 | + block_start = pid * SPLIT_SIZE |
| 39 | + offsets = block_start + tl.arange(0, SPLIT_SIZE) |
| 40 | + mask = offsets < n |
| 41 | + out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK) |
| 42 | + tl.store(out_ptr + offsets, out_dq, mask) |
| 43 | + |
| 44 | + |
| 45 | +def dequant_8bit_blockwise( |
| 46 | + a: torch.Tensor, |
| 47 | + absmax: torch.Tensor, |
| 48 | + quant_state_code: torch.Tensor, |
| 49 | + quant_blocksize: int = 64, |
| 50 | + dtype: torch.dtype = None, |
| 51 | + out: torch.Tensor = None, |
| 52 | +): |
| 53 | + n = a.numel() |
| 54 | + if out is None: |
| 55 | + if dtype is None: |
| 56 | + raise ValueError("If out is None, dtype must be specified") |
| 57 | + out = torch.empty_like(a, dtype=dtype, device=a.device) |
| 58 | + |
| 59 | + SPLIT_SIZE = 256 |
| 60 | + # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) |
| 61 | + grid = (triton.cdiv(n, SPLIT_SIZE),) |
| 62 | + dequant_8bit_kernel[grid]( |
| 63 | + a, |
| 64 | + out, |
| 65 | + quant_state_code, |
| 66 | + absmax, |
| 67 | + n, |
| 68 | + quant_blocksize, |
| 69 | + SPLIT_SIZE, |
| 70 | + ) |
| 71 | + return out |
| 72 | + |
| 73 | + |
| 74 | +# @triton.autotune( |
| 75 | +# configs=[ |
| 76 | +# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), |
| 77 | +# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), |
| 78 | +# triton.Config({"SPLIT_NUM_BLOCKS": 1}), |
| 79 | +# triton.Config({"SPLIT_NUM_BLOCKS": 2}), |
| 80 | +# ], |
| 81 | +# key=["n_elements"], |
| 82 | +# ) |
| 83 | +@triton.jit |
| 84 | +def quantize_8bit_blockwise_kernel( |
| 85 | + A_ptr, |
| 86 | + code_ptr, |
| 87 | + absmax_ptr, |
| 88 | + out_ptr, |
| 89 | + n_elements, |
| 90 | + BLOCK_SIZE: tl.constexpr, |
| 91 | + CODE_SIZE: tl.constexpr, |
| 92 | + SPLIT_NUM_BLOCKS: tl.constexpr, |
| 93 | +): |
| 94 | + block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS |
| 95 | + thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) |
| 96 | + |
| 97 | + offsets = block_start_idx * BLOCK_SIZE + thread_idx |
| 98 | + mask = offsets < n_elements |
| 99 | + |
| 100 | + A = tl.load(A_ptr + offsets, mask=mask, other=0.0) |
| 101 | + |
| 102 | + quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS) |
| 103 | + tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) |
| 104 | + tl.store(out_ptr + offsets, quantized, mask=mask) |
| 105 | + |
| 106 | + |
| 107 | +def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None): |
| 108 | + n = A.numel() |
| 109 | + blocks = -(n // -blocksize) |
| 110 | + |
| 111 | + if absmax is None: |
| 112 | + absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) |
| 113 | + if out is None: |
| 114 | + out = torch.empty_like(A.flatten(), dtype=torch.uint8) |
| 115 | + |
| 116 | + split_num_blocks = 1 |
| 117 | + grid = (triton.cdiv(blocks, split_num_blocks),) |
| 118 | + # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) |
| 119 | + quantize_8bit_blockwise_kernel[grid]( |
| 120 | + A_ptr=A, |
| 121 | + code_ptr=code, |
| 122 | + absmax_ptr=absmax, |
| 123 | + out_ptr=out, |
| 124 | + n_elements=n, |
| 125 | + BLOCK_SIZE=blocksize, |
| 126 | + CODE_SIZE=code.numel(), |
| 127 | + SPLIT_NUM_BLOCKS=split_num_blocks, |
| 128 | + # num_warps=1, |
| 129 | + # num_stages=2, |
| 130 | + ) |
| 131 | + out = out.reshape(A.shape) |
| 132 | + |
| 133 | + return out, absmax |
| 134 | + |
| 135 | + |
| 136 | +@triton.jit |
| 137 | +def quantize_8bit_blockwise_kernel_util( |
| 138 | + a, |
| 139 | + code_ptr, |
| 140 | + CODE_SIZE: tl.constexpr, |
| 141 | + BLOCK_SIZE: tl.constexpr, |
| 142 | + N_PER_TH: tl.constexpr, |
| 143 | +): |
| 144 | + # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) |
| 145 | + a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE)) |
| 146 | + |
| 147 | + # Calculating absmax for each block |
| 148 | + absmax = tl.max(tl.abs(a_reshaped), axis=1) |
| 149 | + |
| 150 | + a_normalized = a_reshaped / absmax[:, None] |
| 151 | + a_normalized = tl.clamp(a_normalized, -1.0, 1.0) |
| 152 | + |
| 153 | + lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32) |
| 154 | + upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) |
| 155 | + |
| 156 | + # ceil(log2(code_size)) = 8, actually, in general case should be input parameter |
| 157 | + for _ in range(8): |
| 158 | + pivot = (lower_pivot + upper_pivot) // 2 |
| 159 | + val = tl.load(code_ptr + pivot) |
| 160 | + is_higher = a_normalized > val # code[pivot] |
| 161 | + lower_pivot = tl.where(is_higher, pivot, lower_pivot) |
| 162 | + upper_pivot = tl.where(is_higher, upper_pivot, pivot) |
| 163 | + |
| 164 | + # Choose closest level |
| 165 | + lower_val = tl.load(code_ptr + lower_pivot) |
| 166 | + upper_val = tl.load(code_ptr + upper_pivot) |
| 167 | + lower_dist = tl.abs(a_normalized - lower_val) |
| 168 | + upper_dist = tl.abs(a_normalized - upper_val) |
| 169 | + quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) |
| 170 | + |
| 171 | + # too slow approach |
| 172 | + # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) |
| 173 | + # quantized = tl.argmin(diff, axis=2).to(tl.uint8) |
| 174 | + |
| 175 | + quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,)) |
| 176 | + return quantized_flat, absmax |
| 177 | + |
| 178 | + |
| 179 | +@triton.jit |
| 180 | +def dequant_8bit_blockwise_kernel_util( |
| 181 | + a_ptr, |
| 182 | + offsets, |
| 183 | + code_ptr, |
| 184 | + absmax_ptr, |
| 185 | + mask, |
| 186 | + BLOCK_SIZE: tl.constexpr, |
| 187 | +): |
| 188 | + a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8) |
| 189 | + scaled_int8 = tl.load(code_ptr + a, mask) |
| 190 | + # Load scales |
| 191 | + absmax_offsets = offsets // BLOCK_SIZE |
| 192 | + absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last") |
| 193 | + # Apply scales |
| 194 | + out_dq = scaled_int8 * absmax |
| 195 | + return out_dq |
0 commit comments