Commit 3df0da5

[KERNELS] Fix benchmarking fp8 on hopper (#7629)
Weights used by `matmul_ogs` must be in column-major order for Hopper and earlier architectures with the fp8 type.
1 parent: 2e359d3
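"Hopper and earlier" here means CUDA devices with compute capability below 10.0 (Blackwell). As a rough stand-in for the gate the diff below adds via triton_kernels' is_cuda() and cuda_capability_geq(10, 0) helpers, the same condition can be sketched in plain PyTorch; the function name below is hypothetical, not part of the patch.

    import torch

    def needs_column_major_fp8_weight() -> bool:
        # Hypothetical stand-in for: is_cuda() and not cuda_capability_geq(10, 0).
        # True on CUDA GPUs below compute capability 10.0, i.e. Hopper (9.x)
        # and earlier; False on non-CUDA backends and on Blackwell and newer.
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability()
        return major < 10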

python/triton_kernels/bench/bench_mlp.py (3 additions, 1 deletion)

@@ -10,7 +10,7 @@
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
-from triton_kernels.target_info import is_hip, get_cdna_version
+from triton_kernels.target_info import is_cuda, is_hip, get_cdna_version, cuda_capability_geq
 from triton_kernels.tensor import convert_layout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
 from dataclasses import dataclass
@@ -32,6 +32,8 @@ def quantize(w, dtype, **opt):
         fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 \
             else torch.float8_e4m3fnuz
         wq = w.to(fp8e4_dtype)
+        if is_cuda() and not cuda_capability_geq(10, 0):
+            wq = wq.transpose(-1, -2).contiguous().transpose(-1, -2)
         return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), None
     else:
         assert dtype == "mx4", f"{dtype=}"
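
For context, a minimal sketch of the layout trick the added lines use, in plain PyTorch (float16 stands in for fp8 so the snippet runs on any build): transpose, materialize contiguously, then transpose back. The result keeps the same logical shape and values but gets column-major strides.

    import torch

    # A row-major weight: shape (4, 8), strides (8, 1).
    w = torch.randn(4, 8, dtype=torch.float16)
    print(w.shape, w.stride())    # torch.Size([4, 8]) (8, 1)

    # Transpose -> contiguous() -> transpose back: the copy is laid out
    # contiguously in the transposed shape, so the final view has the
    # original shape with column-major strides (1, 4).
    wq = w.transpose(-1, -2).contiguous().transpose(-1, -2)
    print(wq.shape, wq.stride())  # torch.Size([4, 8]) (1, 4)

    # Values are unchanged; only the in-memory layout differs.
    assert torch.equal(w, wq)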
