Commit 0f7201e

enable chunk_size=8
Signed-off-by: cliu-us <[email protected]>
1 parent 2f0f780 commit 0f7201e

File tree: 2 files changed, +47 -8 lines

  fms_mo/custom_ext_kernels/triton_kernels.py
  fms_mo/modules/linear.py

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 14 additions & 4 deletions
@@ -323,10 +323,20 @@ def isPowerofTwo(x):
         return (x & (x - 1)) == 0
 
     min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
-    if isPowerofTwo(chunk_size):
-        chunk_size = max(chunk_size, min_chunk_size)
+
+    # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
+    # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n] -> [2k,n], out=[m,n] unchanged.
+    # Do not support I8 or F8 for now. (as F8/FP24 simulation is treated as BF16 currently)
+    if chunk_size == 8 and a.dtype in [torch.float16, torch.bfloat16]:
+        a_padded = torch.zeros(a.shape[0], a.shape[1]*2, dtype=a.dtype, device=a.device)
+        a_padded[:, ::2] = a
+        a = a_padded
+        b_padded = torch.zeros(b.shape[0]*2, b.shape[1], dtype=b.dtype, device=b.device)
+        b_padded[::2, :] = b
+        b = b_padded
+        chunk_size = 16
     else:
-        chunk_size = min_chunk_size
+        chunk_size = max(chunk_size, min_chunk_size) if isPowerofTwo(chunk_size) else min_chunk_size
 
     if a.dtype in DTYPE_I8:
         acc_dtype = torch.int32
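
The zero-interleaving above is safe because every extra term in each dot product is a zero times a zero: padding A from [m, k] to [m, 2k] and B from [k, n] to [2k, n] with zeros in the odd positions leaves A @ B unchanged, while the effective K seen by the kernel doubles to 16. A minimal standalone check of that identity (not part of the commit; plain float32 on CPU, since the algebra does not depend on dtype, whereas the kernel path above only takes it for fp16/bf16):

import torch

m, k, n = 4, 8, 5
a = torch.randn(m, k)
b = torch.randn(k, n)

# interleave zeros exactly as the wrapper above does: even slots hold the data
a_padded = torch.zeros(m, 2 * k)
a_padded[:, ::2] = a
b_padded = torch.zeros(2 * k, n)
b_padded[::2, :] = b

# every added product term is 0 * 0, so the result is unchanged
torch.testing.assert_close(a @ b, a_padded @ b_padded)
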
@@ -345,7 +355,7 @@ def grid(META):
 
     if M < 1024 or N < 1024:
         kernel_config = {
-            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_K": chunk_size,
             "BLOCK_SIZE_N": 32,
             "GROUP_SIZE_M": 8,

fms_mo/modules/linear.py

Lines changed: 33 additions & 4 deletions
@@ -1797,7 +1797,7 @@ class LinearFuncFPxFwdBwd(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16):
+    def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False):
         assert x.dtype in [torch.float, torch.bfloat16, torch.float16]
         # input can be 2D or 3D, need to reshape before tl_matmul
         org_dtype = x.dtype
org_dtype = x.dtype
@@ -1813,6 +1813,20 @@ def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16):
         ctx.save_for_backward(x, weight)  # x, W are saved in their original dtype
         ctx.trun_bits = trun_bits
         ctx.chunk_size = chunk_size
+        ctx.fp8_dyn = fp8_dyn
+
+        if fp8_dyn:
+            # use Q/dQ simulation for now, meaning still compute in fp16/bf16
+            # if choose per_token for input, use per_channel for W
+            # (W saved as [out, in], reduce inCh-dim, => reduce_dim=1)
+            ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max
+            ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max
+            reduce_dim = None if fp8_dyn == "per_tensor" else 1
+            x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max
+            w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max
+
+            x = (x/x_scale).to(torch.float8_e4m3fn).to(org_dtype)*x_scale
+            weight = (weight/w_scale).to(torch.float8_e4m3fn).to(org_dtype)*w_scale
 
         # triton kernel assumes 2D inputs and cast the return to input.dtype
         output = tl_matmul(
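
The fp8_dyn branch is a quantize/dequantize (Q/dQ) simulation: each tensor is scaled into the float8_e4m3 range, cast to float8 and immediately back, so the matmul itself still runs in fp16/bf16. Below is a rough standalone sketch of the same idea; fake_quant_fp8 is a hypothetical helper (not in the repo), and the keepdim/clamp handling for the per-row case is an addition in this sketch, not something the committed code does:

import torch

def fake_quant_fp8(t, fmt=torch.float8_e4m3fn, reduce_dim=None):
    """Hypothetical helper: dynamic Q/dQ simulation; compute stays in t.dtype."""
    fp8_max = torch.finfo(fmt).max
    if reduce_dim is None:
        scale = t.abs().amax() / fp8_max                               # one per-tensor scale
    else:
        scale = t.abs().amax(dim=reduce_dim, keepdim=True) / fp8_max   # one scale per row
    scale = scale.clamp(min=torch.finfo(torch.float32).tiny)           # guard all-zero rows
    return (t / scale).to(fmt).to(t.dtype) * scale

x = torch.randn(4, 16, dtype=torch.bfloat16)   # activations [batch, in]
w = torch.randn(8, 16, dtype=torch.bfloat16)   # weight [out, in]

x_q = fake_quant_fp8(x)                 # per-tensor e4m3, as in the forward pass above
w_q = fake_quant_fp8(w, reduce_dim=1)   # per-channel over the in-features dim of [out, in]
y = x_q @ w_q.t()                       # the matmul itself is still bf16
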
@@ -1840,6 +1854,18 @@ def backward(ctx, grad_output):
         target_shape_grad_input = grad_output.shape[:-1] + (in_dim,)
         grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_input)
 
+        if ctx.fp8_dyn:
+            reduce_dim = None if ctx.fp8_dyn == "per_tensor" else 1
+            x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max
+            w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max
+            grad_out_scale = grad_output_2D.abs().amax(dim=None) / ctx.fp8_e5m2_max  # always perT
+
+            x = (x/x_scale).to(torch.float8_e5m2).to(dtype_input)*x_scale
+            weight = (weight/w_scale).to(torch.float8_e5m2).to(weight.dtype)*w_scale
+            grad_output_2D = (grad_output_2D/grad_out_scale).to(torch.float8_e5m2
+                ).to(grad_output.dtype
+                )*grad_out_scale
+
         # Compute grad_weight, shape = [out, in]
         # NOTE: this triton kernel requires A matrix to be contiguous
         grad_weight = tl_matmul(
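
In the backward branch the same Q/dQ trick switches to torch.float8_e5m2, which trades mantissa bits for exponent bits and is the usual choice for gradients given their wider dynamic range; the incoming gradient always gets a single per-tensor scale. A self-contained sketch of just that gradient step, with an illustrative stand-in tensor (the real code then feeds the result into the tl_matmul calls that follow):

import torch

grad_output_2D = torch.randn(4, 8, dtype=torch.bfloat16)    # stand-in upstream gradient
e5m2_max = torch.finfo(torch.float8_e5m2).max               # 57344.0

# one scale for the whole tensor ("always perT" in the diff above)
grad_out_scale = grad_output_2D.abs().amax() / e5m2_max
grad_q = (grad_output_2D / grad_out_scale).to(torch.float8_e5m2)   # quantize
grad_q = grad_q.to(grad_output_2D.dtype) * grad_out_scale          # dequantize back to bf16
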
@@ -1865,7 +1891,7 @@ def backward(ctx, grad_output):
         else:
             grad_bias = grad_output_2D.sum(0).to(ctx.bias_dtype)
 
-        return grad_input, grad_weight, grad_bias, None
+        return grad_input, grad_weight, grad_bias, None, None, None
 
 
 class LinearFPxAcc(torch.nn.Linear):
@@ -1906,20 +1932,23 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs):
 
         lin24acc.weight = nnlin.weight
         lin24acc.trun_bits = trun_bits
+        lin24acc.chunk_size = kwargs.get("chunk_size", False)
+        lin24acc.fp8_dyn = kwargs.get("dynamic_fp8", False)  # ["per_tensor", "per_token"]
 
         if nnlin.bias is not None:
             lin24acc.bias = nnlin.bias
         return lin24acc.to(target_device)
 
     def forward(self, inputs):
         # This Linear Class will cast to BF16 before matmul and return FP32
-        return LinearFuncFPxFwdBwd.apply(inputs, self.weight, self.bias, self.trun_bits)
+        return LinearFuncFPxFwdBwd.apply(inputs, self.weight, self.bias, self.trun_bits,
+                                         self.chunk_size, self.fp8_dyn)
 
     def extra_repr(self) -> str:
         """
         Returns an alternative string representation of the object.
         """
         return (
             f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
-            f"trun_bits={self.trun_bits}"
+            f"trun_bits={self.trun_bits},fp8_dyn={self.fp8_dyn},chunk_size={self.chunk_size}"
         )
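
Taken together, from_nn now threads the two new kwargs, chunk_size and dynamic_fp8, through forward and into the autograd function (hence the extra None gradients returned by backward). A hedged usage sketch follows; the import path is an assumption, and actually running it needs a CUDA device with Triton because forward goes through tl_matmul:

import torch
from fms_mo.modules.linear import LinearFPxAcc   # assumed import path for the file above

lin = torch.nn.Linear(64, 32, dtype=torch.bfloat16, device="cuda")

# chunk_size=8 exercises the new zero-interleaving path in tl_matmul;
# dynamic_fp8 turns on the Q/dQ simulation ("per_tensor" or "per_token")
lin_fpx = LinearFPxAcc.from_nn(lin, trun_bits=0, chunk_size=8, dynamic_fp8="per_tensor")

x = torch.randn(4, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = lin_fpx(x)
y.sum().backward()   # backward also runs the fp8_dyn branch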
