Skip to content

Commit dbd540d

Browse files
minor linting
Signed-off-by: cliu-us <[email protected]>
1 parent aba51f8 commit dbd540d

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,12 @@ def isPowerofTwo(x):
473473
# because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
474474
# insert 0s in between elements, e.g. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
475475
# Do not support INT8 for now.
476-
if chunk_size == 8 and a.dtype in [torch.float8_e4m3fn, torch.float16, torch.bfloat16]:
477-
exp_ratio = min_chunk_size//chunk_size
476+
if chunk_size == 8 and a.dtype in [
477+
torch.float8_e4m3fn,
478+
torch.float16,
479+
torch.bfloat16,
480+
]:
481+
exp_ratio = min_chunk_size // chunk_size
478482
a_padded = torch.zeros(
479483
a.shape[0], a.shape[1] * exp_ratio, dtype=a.dtype, device=a.device
480484
)

0 commit comments

Comments
 (0)