-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathint8_fused_kernel.py
More file actions
372 lines (303 loc) · 13.3 KB
/
int8_fused_kernel.py
File metadata and controls
372 lines (303 loc) · 13.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice
# =============================================================================
# Kernel 1: Fused Row-wise Quantization (FP16/BF16 -> INT8 + Scale)
# =============================================================================
@triton.jit
def _quantize_rowwise_kernel(
    x_ptr,  # Input pointer (FP16/BF16/FP32); rows are read at a stride of n_elements
    y_ptr,  # Output pointer (INT8); rows are written at a stride of n_elements
    s_ptr,  # Scale pointer (FP32); one value is written per row
    n_elements,  # Number of columns
    BLOCK_SIZE: tl.constexpr,
):
    """
    Symmetric per-row INT8 quantization; one program handles one row.

    For the row selected by program_id(0):
        scale = max(max(|x|) / 127, 1e-30)
        y     = clamp(rint(x / scale), -128, 127) stored as int8
        s[row] = scale

    The whole row is loaded in a single block, so BLOCK_SIZE must be
    >= n_elements (lanes past n_elements are masked off).
    """
    # Row index we are processing
    row_idx = tl.program_id(0)

    # Pointers to the start of the row. The offset is computed in 64-bit so
    # rows * cols may exceed the int32 range.
    row_start = row_idx.to(tl.int64) * n_elements
    x_row_ptr = x_ptr + row_start
    y_row_ptr = y_ptr + row_start

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # 1. Load the row and up-cast to FP32 immediately. Doing the remaining
    #    math in FP32 keeps the 1e-30 floor below representable (it is far
    #    under the fp16 range) and keeps the scale at full precision,
    #    independent of how the installed Triton version promotes
    #    half-precision tensors against Python float constants.
    x = tl.load(x_row_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # 2. Max absolute value of the row (masked lanes contribute 0)
    max_val = tl.max(tl.abs(x), axis=0)

    # 3. Scale; the floor avoids a division by zero for all-zero rows
    scale = tl.maximum(max_val / 127.0, 1e-30)

    # 4. Quantize: round via libdevice.rint, then clamp to the int8 range
    #    while still in floating point.
    q_f = libdevice.rint(x / scale)
    q_f = tl.clamp(q_f, -128.0, 127.0)

    # 5. Store the quantized row and its scale
    tl.store(y_row_ptr + offsets, q_f.to(tl.int8), mask=mask)
    tl.store(s_ptr + row_idx, scale)
def triton_quantize_rowwise(x: torch.Tensor):
    """
    Quantize each row of a 2D tensor to INT8 with its own scale.

    Input:  [Batch, Dim] (float16/bfloat16/float32)
    Output: [Batch, Dim] (int8), [Batch, 1] (float32)

    Raises:
        ValueError: if ``x`` is not 2-dimensional.
    """
    if x.dim() != 2:
        raise ValueError(
            f"triton_quantize_rowwise expects a 2D tensor, got shape {tuple(x.shape)}"
        )

    # The kernel addresses row r at base + r * cols, i.e. it assumes densely
    # packed row-major storage. Make that true for views / transposed inputs.
    x = x.contiguous()
    rows, cols = x.shape

    # Allocate explicitly instead of empty_like so the output is always
    # row-major, matching how the kernel writes it.
    y = torch.empty((rows, cols), device=x.device, dtype=torch.int8)
    s = torch.empty((rows, 1), device=x.device, dtype=torch.float32)

    # Heuristic for block size: the kernel loads a whole row in one block, so
    # BLOCK_SIZE is the next power of two >= cols, with a floor of 128.
    # Note: very wide rows would be better served by a looped reduction inside
    # the kernel. For Flux2 Klein, Z-Image, Chroma layers this appears fine afaik.
    BLOCK_SIZE = max(triton.next_power_of_2(cols), 128)

    grid = (rows,)
    _quantize_rowwise_kernel[grid](x, y, s, cols, BLOCK_SIZE=BLOCK_SIZE)
    return y, s
# =============================================================================
# Kernel 2: INT8 GEMM + Fused Dequantization Epilogue
# =============================================================================
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _int8_matmul_dequant_kernel(
    # Pointers
    a_ptr, b_ptr, c_ptr,
    a_scale_ptr, b_scale_ptr, bias_ptr,
    # Matrix Dimensions
    M, N, K,
    # Strides
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    HAS_BIAS: tl.constexpr
):
    """
    Computes: C = ((A * B) * (scale_a * scale_b)) + bias

    A: [M, K] int8
    B: [N, K] int8 (Transposed physically or logically via strides; the kernel
       indexes it as [K, N] through stride_bk / stride_bn)
    scale_a: one value per row of A, read at a_scale_ptr + row
    scale_b: a single value read from b_scale_ptr
    bias: one value per output column, read only when HAS_BIAS is set
    """
    # Map the 1D program id to a (pid_m, pid_n) tile, visiting tiles in groups
    # of GROUP_SIZE_M rows.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # 1. Prepare offsets and pointers
    # Unwrapped tile coordinates: used for the final store and its bounds mask.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrapped coordinates: keep every load in-bounds on tail tiles without
    # needing an M/N mask. Lanes that wrapped compute throw-away values which
    # the store mask below discards.
    offs_am = offs_cm % M
    offs_bn = offs_cn % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # 2. Main Loop (Accumulate in Int32)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Load chunks; the K tail is zero-filled so it adds nothing to the dot
        k_remaining = K - k * BLOCK_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0)
        # Matrix Multiply (Int8 inputs -> Int32 accum)
        accumulator += tl.dot(a, b)
        # Advance pointers
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # 3. Fused Epilogue (Dequantize & Bias)
    # A scale is per-row
    scale_a = tl.load(a_scale_ptr + offs_am)  # Vector [BLOCK_M]
    # B scale is a single scalar
    scale_b = tl.load(b_scale_ptr)

    # Convert accumulator to float and apply the combined scale
    c = accumulator.to(tl.float32)
    total_scale = scale_a[:, None] * scale_b
    c = c * total_scale

    # Add Bias if present
    if HAS_BIAS:
        bias = tl.load(bias_ptr + offs_bn)  # Vector [BLOCK_N]
        c = c + bias[None, :]

    # 4. Store Result. The mask is built from the UNWRAPPED offsets so that
    #    out-of-range lanes of tail tiles are actually dropped instead of
    #    being written back onto rows/columns 0...
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    # Explicit cast to the output element type
    tl.store(c_ptrs, c.to(c_ptr.dtype.element_ty), mask=c_mask)
# =============================================================================
# Python Wrapper
# =============================================================================
def triton_int8_linear(x: torch.Tensor, weight: torch.Tensor, weight_scale, bias=None, compute_dtype=torch.float16):
    """
    Fused pipeline for W8A8 Linear Layer with a single (per-tensor) weight scale.

    Args:
        x: activations, [..., K]; leading dimensions are flattened into rows.
        weight: quantized weight, [N, K] (PyTorch Linear layout).
        weight_scale: Python number or single-element tensor.
        bias: optional tensor with N elements.
        compute_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape [..., N] and dtype ``compute_dtype``.

    Raises:
        ValueError: if ``weight`` is not [N, K] for the K of ``x``, if
            ``weight_scale`` is a tensor with more than one element, or if
            ``bias`` does not have N elements.
    """
    # 1. Flatten inputs if 3D [Batch, Tokens, Dim] -> [Batch*Tokens, Dim]
    #    The row-wise quantization kernel assumes densely packed rows, so make
    #    the flattened view contiguous (no-op when it already is).
    x_shape_orig = x.shape
    x_2d = x.reshape(-1, x_shape_orig[-1]).contiguous()
    M, K = x_2d.shape

    if weight.dim() != 2 or weight.shape[1] != K:
        raise ValueError(
            f"weight must have shape [N, {K}], got {tuple(weight.shape)}"
        )
    N = weight.shape[0]

    # 2. Kernel 1: Dynamic Activation Quantization
    # (This is much faster than Python-loop based axiswise quant)
    x_int8, x_scale = triton_quantize_rowwise(x_2d)

    # 3. Allocate Output
    output = torch.empty((M, N), device=x.device, dtype=compute_dtype)

    # 4. Prepare Scales for Kernel
    # The kernel reads exactly one value from the weight-scale pointer, so a
    # multi-element scale would be silently truncated to its first entry.
    if not isinstance(weight_scale, torch.Tensor):
        weight_scale = torch.tensor([weight_scale], device=x.device, dtype=torch.float32)
    else:
        if weight_scale.numel() != 1:
            raise ValueError(
                "triton_int8_linear expects a single weight scale, got "
                f"{weight_scale.numel()} elements; use triton_int8_linear_per_row "
                "for per-row scales"
            )
        weight_scale = weight_scale.reshape(1).to(device=x.device)

    # Check if we have bias
    has_bias = bias is not None
    if has_bias:
        if bias.numel() != N:
            raise ValueError(f"bias must have {N} elements, got {bias.numel()}")
        # The kernel indexes bias with unit stride
        bias_ptr = bias.contiguous()
    else:
        bias_ptr = x  # Dummy pointer if None; never dereferenced

    # 5. Kernel 2: Fused GEMM + Dequant
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )

    # NOTE: PyTorch Linear weights are [Out, In] (N, K).
    # The kernel expects B to be [K, N] logically.
    # Since weight is [N, K], we read it as [K, N] by swapping its strides:
    # stride_bk = weight.stride(1), stride_bn = weight.stride(0).
    _int8_matmul_dequant_kernel[grid](
        # Pointers
        a_ptr=x_int8,
        b_ptr=weight,
        c_ptr=output,
        a_scale_ptr=x_scale,
        b_scale_ptr=weight_scale,
        bias_ptr=bias_ptr,
        # Shapes
        M=M, N=N, K=K,
        # Strides
        stride_am=x_int8.stride(0), stride_ak=x_int8.stride(1),
        stride_bk=weight.stride(1), stride_bn=weight.stride(0),  # Transposed access of W
        stride_cm=output.stride(0), stride_cn=output.stride(1),
        # Meta
        HAS_BIAS=has_bias
    )

    # 6. Reshape output
    return output.reshape(x_shape_orig[:-1] + (N,))
# =============================================================================
# Kernel 3: INT8 GEMM + Fused Dequant with Per-Row Weight Scales
# =============================================================================
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _int8_matmul_dequant_per_row_kernel(
    # Pointers
    a_ptr, b_ptr, c_ptr,
    a_scale_ptr, b_scale_ptr, bias_ptr,
    # Matrix Dimensions
    M, N, K,
    # Strides
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    HAS_BIAS: tl.constexpr
):
    """
    Computes: C = ((A * B) * (scale_a[:, None] * scale_b[None, :])) + bias

    A: [M, K] int8, scale_a: per-row activation scales, read at a_scale_ptr + row
    B: [N, K] int8, scale_b: per-row weight scales, read at b_scale_ptr + col
       (B is indexed as [K, N] through stride_bk / stride_bn)
    bias: one value per output column, read only when HAS_BIAS is set
    """
    # Map the 1D program id to a (pid_m, pid_n) tile, visiting tiles in groups
    # of GROUP_SIZE_M rows.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # 1. Prepare offsets and pointers
    # Unwrapped tile coordinates: used for the final store and its bounds mask.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrapped coordinates: keep every load in-bounds on tail tiles without
    # needing an M/N mask. Lanes that wrapped compute throw-away values which
    # the store mask below discards.
    offs_am = offs_cm % M
    offs_bn = offs_cn % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # 2. Main Loop (Accumulate in Int32)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # The K tail is zero-filled so it adds nothing to the dot product
        k_remaining = K - k * BLOCK_K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # 3. Fused Epilogue (Dequantize & Bias)
    # A scale is per-row of A
    scale_a = tl.load(a_scale_ptr + offs_am)  # Vector [BLOCK_M]
    # B scale is per-row of B (the key difference from the scalar kernel)
    scale_b = tl.load(b_scale_ptr + offs_bn)  # Vector [BLOCK_N]

    c = accumulator.to(tl.float32)
    # Outer product of scales: [BLOCK_M, 1] * [1, BLOCK_N]
    total_scale = scale_a[:, None] * scale_b[None, :]
    c = c * total_scale

    if HAS_BIAS:
        bias = tl.load(bias_ptr + offs_bn)
        c = c + bias[None, :]

    # 4. Store Result. The mask is built from the UNWRAPPED offsets so that
    #    out-of-range lanes of tail tiles are actually dropped instead of
    #    being written back onto rows/columns 0...
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    # Explicit cast to the output element type
    tl.store(c_ptrs, c.to(c_ptr.dtype.element_ty), mask=c_mask)
# =============================================================================
# Python Wrapper (Per-Row Weight Scales)
# =============================================================================
def triton_int8_linear_per_row(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, bias=None, compute_dtype=torch.float16):
    """
    Fused pipeline for W8A8 Linear Layer with per-row weight quantization.

    Args:
        x: activations, [..., K]; leading dimensions are flattened into rows.
        weight: quantized weight, [N, K] (PyTorch Linear layout).
        weight_scale: [N, 1] per-row scales (any shape with N elements).
        bias: optional tensor with N elements.
        compute_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape [..., N] and dtype ``compute_dtype``.

    Raises:
        ValueError: if ``weight`` is not [N, K] for the K of ``x`` or if
            ``bias`` does not have N elements.
        RuntimeError: from ``reshape`` if ``weight_scale`` does not have N
            elements.
    """
    # 1. Flatten inputs if 3D
    #    The row-wise quantization kernel assumes densely packed rows, so make
    #    the flattened view contiguous (no-op when it already is).
    x_shape_orig = x.shape
    x_2d = x.reshape(-1, x_shape_orig[-1]).contiguous()
    M, K = x_2d.shape

    if weight.dim() != 2 or weight.shape[1] != K:
        raise ValueError(
            f"weight must have shape [N, {K}], got {tuple(weight.shape)}"
        )
    N = weight.shape[0]

    # 2. Dynamic Activation Quantization
    x_int8, x_scale = triton_quantize_rowwise(x_2d)

    # 3. Allocate Output
    output = torch.empty((M, N), device=x.device, dtype=compute_dtype)

    # 4. Prepare weight scales - flatten [N, 1] -> [N] for kernel, which
    #    indexes them with unit stride on the activation's device.
    ws = weight_scale.reshape(N).to(device=x.device).contiguous()

    has_bias = bias is not None
    if has_bias:
        if bias.numel() != N:
            raise ValueError(f"bias must have {N} elements, got {bias.numel()}")
        # The kernel indexes bias with unit stride
        bias_ptr = bias.contiguous()
    else:
        bias_ptr = x  # Dummy pointer if None; never dereferenced

    # 5. Fused GEMM + Per-Row Dequant
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )

    # weight is [N, K]; swapping its strides lets the kernel read it as [K, N].
    _int8_matmul_dequant_per_row_kernel[grid](
        a_ptr=x_int8,
        b_ptr=weight,
        c_ptr=output,
        a_scale_ptr=x_scale,
        b_scale_ptr=ws,
        bias_ptr=bias_ptr,
        M=M, N=N, K=K,
        stride_am=x_int8.stride(0), stride_ak=x_int8.stride(1),
        stride_bk=weight.stride(1), stride_bn=weight.stride(0),
        stride_cm=output.stride(0), stride_cn=output.stride(1),
        HAS_BIAS=has_bias
    )

    # 6. Reshape output
    return output.reshape(x_shape_orig[:-1] + (N,))