
Commit c086d08

Load scales instead of constexpr (#684)
Load scales from global memory. Scales are typically not provided as constexprs by users.
1 parent cd6f51b · commit c086d08
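
For context, a minimal host-side sketch of what this change means for callers: the per-tensor scales stay resident in global memory as small device tensors and the kernel reads them with tl.load, instead of the wrapper folding a pre-multiplied scalar into the launch. The tensor shapes, dtypes, output dtype, and import path below are illustrative assumptions, not code from this commit.

    # Illustrative sketch only: shapes, dtypes, and the import path are assumptions.
    import torch
    from gemm import matmul  # assumed import of python/perf-kernels/gemm.py

    M, N, K = 512, 512, 512
    a = torch.randn((M, K), device='cuda').to(torch.float8_e4m3fnuz)
    b = torch.randn((K, N), device='cuda').to(torch.float8_e4m3fnuz)
    c = torch.empty((M, N), device='cuda', dtype=torch.float16)

    # Per-tensor scales live in global memory as 1-element tensors; inside the
    # kernel they are read with tl.load(a_scale_ptr) / tl.load(b_scale_ptr).
    a_scale = torch.tensor([0.5], device='cuda', dtype=torch.float32)
    b_scale = torch.tensor([0.25], device='cuda', dtype=torch.float32)

    matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")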

File tree: 1 file changed, +22 -14 lines


python/perf-kernels/gemm.py

Lines changed: 22 additions & 14 deletions
@@ -55,7 +55,8 @@ def matmul_kernel(
         stride_bn,
         stride_cm,
         stride_cn,
-        scale,
+        a_scale_ptr,
+        b_scale_ptr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr,
         BLOCK_SIZE_N: tl.constexpr,
@@ -92,6 +93,9 @@ def matmul_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    if APPLY_SCALE:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr)
 
     acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
@@ -110,12 +114,13 @@ def matmul_kernel(
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
+    # Apply scale to recover dynamic range reduced due to lower precision inputs.
+    if APPLY_SCALE:
+        accumulator = accumulator * a_scale * b_scale
     # Apply activation function, if specified.
+    # TODO(vgokhale): Add different types of activations.
     if ACTIVATION == "leaky_relu":
         accumulator = leaky_relu(accumulator)
-    # Apply scale to recover dynamic range reduced due to lower precision inputs.
-    if APPLY_SCALE:
-        accumulator = accumulator * scale
     c = accumulator.to(c_ptr.type.element_ty)
 
     # Write back the block of the output matrix C with masks.
@@ -134,15 +139,13 @@ def leaky_relu(x):
 
 
 # Wrapper for gemm kernel.
-def matmul(a, b, c, a_scale, b_scale, activation=""):
+def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions!!!"
     assert a.dtype == b.dtype, "Mixed dtype GEMMs are not supported!!!"
     M, K = a.shape
     K, N = b.shape
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-    apply_scale = a_scale is not None and b_scale is not None
-    scale = a_scale * b_scale if apply_scale else None
     matmul_kernel[grid](
         a,
         b,
@@ -156,8 +159,9 @@ def matmul(a, b, c, a_scale, b_scale, activation=""):
         b.stride(1),
         c.stride(0),
         c.stride(1),
-        scale,
-        APPLY_SCALE=apply_scale,
+        a_scale,
+        b_scale,
+        APPLY_SCALE=scale_a8_b8,
         ACTIVATION=activation,
     )
 
@@ -173,9 +177,12 @@ def matmul(a, b, c, a_scale, b_scale, activation=""):
 }
 
 dtype_max = {
-    torch.float8_e5m2fnuz: 57344,
-    torch.float8_e4m3fnuz: 240,
-    torch.int8: 127,
+    dtype: (torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)).max
+    for dtype in [
+        torch.float8_e5m2fnuz,
+        torch.float8_e4m3fnuz,
+        torch.int8,
+    ]
 }
 
 
@@ -213,6 +220,7 @@ def get_x_vals():
 
 
 # Unit tests
+#TODO(vgokhale): Test activation.
 @pytest.mark.parametrize(
     "M, N, K, in_dtype, out_dtype, col_a, col_b",
     [(*shape, in_dtype, out_dtype, col_a, col_b)
@@ -232,12 +240,12 @@ def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
     # This requires us to compute in fp32 because for e5m2, the range is same as fp16 (e5m10).
     # If we use fp16 it is possible to return infs from the torch.matmul call.
     if dtype_is_8_bit(torch_in_dtype):
-        matmul(a, b, c, a_scale.item(), b_scale.item(), activation="")
+        matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")
         torch_output = torch.matmul(a_fp32, b_fp32)
         torch_output = torch_output * a_scale * b_scale
     # For other dtypes, use the same torch matmul as the dtype.
     else:
-        matmul(a, b, c, a_scale=None, b_scale=None, activation="")
+        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=False, activation="")
         torch_output = torch.matmul(a.to(torch_in_dtype), b.to(torch_in_dtype))
     if out_dtype == 'int8':
         torch.testing.assert_close(c.to(torch.float32),
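
Aside on the dtype_max change above: the dict comprehension is expected to reproduce the hard-coded constants it replaces (57344 for float8_e5m2fnuz, 240 for float8_e4m3fnuz, 127 for int8). A quick sanity check, assuming a PyTorch build that exposes the fnuz float8 dtypes, as recent ROCm builds do:

    # Sanity check (assumption: this PyTorch build exposes the fnuz float8 dtypes).
    import torch

    dtype_max = {
        dtype: (torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)).max
        for dtype in [torch.float8_e5m2fnuz, torch.float8_e4m3fnuz, torch.int8]
    }

    # These match the constants removed by this commit.
    assert dtype_max[torch.float8_e5m2fnuz] == 57344.0
    assert dtype_max[torch.float8_e4m3fnuz] == 240.0
    assert dtype_max[torch.int8] == 127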
