Commit 434e68f

linting
Signed-off-by: cliu-us <[email protected]>
1 parent cf39d3d commit 434e68f

2 files changed: +36, -23 lines

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 17 additions & 12 deletions
@@ -266,7 +266,6 @@ def imatmul_kernel(
         else:
             accumulator = accumulator_inner

-
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
     if ACTIVATION == "leaky_relu":
@@ -281,7 +280,6 @@ def imatmul_kernel(
     tl.store(c_ptrs, c, mask=c_mask)


-
 @triton.jit
 def matmul_kernel_DABC(
     # Pointers to matrices
@@ -311,11 +309,11 @@ def matmul_kernel_DABC(
     ACTIVATION: tl.constexpr,
 ):
     """Kernel for computing the matmul D = A x B + C that include LSB truncation.
-    A has shape (M, K), B has shape (K, N) and C/D has shape (M, N).
+    A has shape (M, K), B has shape (K, N) and C/D has shape (M, N).
     NOTE:
         C should be consistent with accumulator dtype, e.g. fp8xfp8 -> fp32.
         *D ptr is supposed to be the same as C ptr, no need to provide D as arg
-        **we can be used C to verify unintended truncation by CUDA as well.
+        **we can be used C to verify unintended truncation by CUDA as well.
     Args:
         chunk_trun_bits (int): number of LSB to truncate/round. [0 to 23]
     """
@@ -353,9 +351,8 @@ def matmul_kernel_DABC(
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
-    # of fp32 values for higher accuracy.
-    # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0) # should have been cast to fp32 already
+    # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already
+    accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0)
     ## ------ prepare LSB rounding/truncation masks -------
     # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
     # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
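A minimal PyTorch sketch of the LSB rounding/truncation these masks describe, purely as an illustration outside Triton; it assumes the kernel rounds by adding the round bit to the raw fp32 bit pattern before masking, and it ignores inf/NaN edge cases:

import torch

def truncate_fp32_lsb(x: torch.Tensor, trun_bits: int) -> torch.Tensor:
    # Illustrative sketch, not the Triton kernel: drop `trun_bits` mantissa LSBs
    # of an fp32 tensor with rounding, e.g. trun_bits=20 uses
    # trun_mask = 0xFFF00000 and round_bit = 0x00080000 as in the comment above.
    assert x.dtype == torch.float32 and 0 <= trun_bits <= 23
    if trun_bits == 0:
        return x
    bits = x.view(torch.int32)                # reinterpret the fp32 bit pattern
    round_bit = 1 << (trun_bits - 1)          # e.g. 0x00080000 for 20 bits
    trun_mask = -(1 << trun_bits)             # e.g. 0xFFF00000 as a signed int32
    rounded = (bits + round_bit) & trun_mask  # round, then clear the LSBs
    return rounded.view(torch.float32)
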
@@ -477,15 +474,23 @@ def isPowerofTwo(x):
     # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
     # Do not support I8 or F8 for now. (as F8/FP24 simulation is treated as BF16 currently)
     if chunk_size == 8 and a.dtype in [torch.float16, torch.bfloat16]:
-        a_padded = torch.zeros(a.shape[0], a.shape[1]*2, dtype=a.dtype, device=a.device)
+        a_padded = torch.zeros(
+            a.shape[0], a.shape[1] * 2, dtype=a.dtype, device=a.device
+        )
         a_padded[:, ::2] = a
         a = a_padded
-        b_padded = torch.zeros(b.shape[0]*2, b.shape[1], dtype=b.dtype, device=b.device)
+        b_padded = torch.zeros(
+            b.shape[0] * 2, b.shape[1], dtype=b.dtype, device=b.device
+        )
         b_padded[::2, :] = b
         b = b_padded
         chunk_size = 16
     else:
-        chunk_size = max(chunk_size, min_chunk_size) if isPowerofTwo(chunk_size) else min_chunk_size
+        chunk_size = (
+            max(chunk_size, min_chunk_size)
+            if isPowerofTwo(chunk_size)
+            else min_chunk_size
+        )

     M, K = a.shape
     K, N = b.shape
@@ -504,8 +509,8 @@ def isPowerofTwo(x):
     # if C is in fp16, accumulate in fp32 no matter what, decide whether to cast back later
     c_org_dtype = c.dtype
     c = c.to(acc_dtype)
-    assert c.shape[0]==M and c.shape[1]==N, "C shape is inconsistent with A B."
-    assert acc_dtype==torch.float32, "INT truncation experiment is not yet supported."
+    assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B."
+    assert acc_dtype == torch.float32, "INT truncation is not yet supported."

     # 1D launch kernel where each block gets its own program.
     def grid(META):
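The chunk_size == 8 branch earlier in this file relies on the fact that interleaving zeros along K leaves A @ B unchanged, so the kernel can run with chunk_size = 16 on the padded operands while each chunk still accumulates only 8 real products. A small standalone check of that identity (plain PyTorch, fp32 for simplicity):

import torch

# Zero-interleaved padding leaves A @ B unchanged: the extra rows/columns
# contribute only zero products, so only the accumulation chunking differs.
a = torch.randn(4, 6)
b = torch.randn(6, 5)

a_padded = torch.zeros(a.shape[0], a.shape[1] * 2, dtype=a.dtype, device=a.device)
a_padded[:, ::2] = a
b_padded = torch.zeros(b.shape[0] * 2, b.shape[1], dtype=b.dtype, device=b.device)
b_padded[::2, :] = b

assert torch.allclose(a_padded @ b_padded, a @ b)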

fms_mo/modules/linear.py

Lines changed: 19 additions & 11 deletions
@@ -1825,8 +1825,8 @@ def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False
             x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max
             w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max

-            x = (x/x_scale).to(torch.float8_e4m3fn).to(org_dtype)*x_scale
-            weight = (weight/w_scale).to(torch.float8_e4m3fn).to(org_dtype)*w_scale
+            x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale
+            weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale

         # triton kernel assumes 2D inputs and cast the return to input.dtype
         output = tl_matmul(
@@ -1858,13 +1858,14 @@ def backward(ctx, grad_output):
             reduce_dim = None if ctx.fp8_dyn == "per_tensor" else 1
             x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max
             w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max
-            grad_out_scale = grad_output_2D.abs().amax(dim=None) / ctx.fp8_e5m2_max # always perT
+            # always assume perT in this case
+            grad_out_scale = grad_output_2D.abs().amax(dim=None) / ctx.fp8_e5m2_max

-            x = (x/x_scale).to(torch.float8_e5m2).to(dtype_input)*x_scale
-            weight = (weight/w_scale).to(torch.float8_e5m2).to(weight.dtype)*w_scale
-            grad_output_2D = (grad_output_2D/grad_out_scale).to(torch.float8_e5m2
-                ).to(grad_output.dtype
-                )*grad_out_scale
+            x = (x / x_scale).to(torch.float8_e5m2).to(dtype_input) * x_scale
+            weight = (weight / w_scale).to(torch.float8_e5m2).to(weight.dtype) * w_scale
+            grad_output_2D = (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to(
+                grad_output.dtype
+            ) * grad_out_scale

         # Compute grad_weight, shape = [out, in]
         # NOTE: this triton kernel requires A matrix to be contiguous
@@ -1933,16 +1934,23 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs):
         lin24acc.weight = nnlin.weight
         lin24acc.trun_bits = trun_bits
         lin24acc.chunk_size = kwargs.get("chunk_size", False)
-        lin24acc.fp8_dyn = kwargs.get("dynamic_fp8", False) #["per_tensor", "per_token"]
+        lin24acc.fp8_dyn = kwargs.get("dynamic_fp8", False)
+        # available options are ["per_tensor", "per_token"]

         if nnlin.bias is not None:
             lin24acc.bias = nnlin.bias
         return lin24acc.to(target_device)

     def forward(self, inputs):
         # This Linear Class will cast to BF16 before matmul and return FP32
-        return LinearFuncFPxFwdBwd.apply(inputs, self.weight, self.bias, self.trun_bits,
-            self.chunk_size, self.fp8_dyn)
+        return LinearFuncFPxFwdBwd.apply(
+            inputs,
+            self.weight,
+            self.bias,
+            self.trun_bits,
+            self.chunk_size,
+            self.fp8_dyn,
+        )

     def extra_repr(self) -> str:
         """
