Commit 03cdcdb

[KERNELS] Fix bench_mlp.py for AMD (#7600)
`opt` is empty if the testing platform is ROCm. I suppose we don't do layout conversion in this case.
1 parent ef72c31 commit 03cdcdb

File tree

1 file changed: +7 -6 lines changed


python/triton_kernels/bench/bench_mlp.py

Lines changed: 7 additions & 6 deletions
@@ -24,7 +24,7 @@
 cublas = None
 
 
-def quantize(w, dtype, dev, **opt):
+def quantize(w, dtype, **opt):
     if dtype == "bf16":
         wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
         return wq, InFlexData(), None
@@ -36,8 +36,9 @@ def quantize(w, dtype, dev, **opt):
     else:
         assert dtype == "mx4", f"{dtype=}"
         w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
-        w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
-        w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
+        if opt:
+            w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
+            w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
         return w, InFlexData(), w_scale
 
 
@@ -109,9 +110,9 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     opt1 = {"value_layout": value_layout, "value_layout_opts": value_layout_opts, \
             "scale_layout": scale_layout, "scale_layout_opts": scale_layout_opts}
     opt2 = deepcopy(opt1)
-    wg, wg_flex, wg_scale = quantize(wg, "bf16", dev, **optg)
-    w1, w1_flex, w1_scale = quantize(w1, w_dtype, dev, **opt1)
-    w2, w2_flex, w2_scale = quantize(w2, w_dtype, dev, **opt2)
+    wg, wg_flex, wg_scale = quantize(wg, "bf16", **optg)
+    w1, w1_flex, w1_scale = quantize(w1, w_dtype, **opt1)
+    w2, w2_flex, w2_scale = quantize(w2, w_dtype, **opt2)
     pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale)
     act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.0, 1.0), 2)
     pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale)

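The fix relies on the fact that an empty `**opt` kwargs dict is falsy in Python, so the new `if opt:` guard silently skips `convert_layout` on ROCm, where no layout options are built. A minimal, self-contained sketch of that behavior (the `demo` function and the `value_layout` value below are hypothetical, not part of the commit):

    def demo(**opt):
        # An empty kwargs dict is falsy, so the layout-conversion branch is skipped,
        # mirroring the `if opt:` guard added to quantize() in this commit.
        if opt:
            return "convert layout"
        return "skip layout conversion"

    print(demo())                               # ROCm path: opt is empty -> "skip layout conversion"
    print(demo(value_layout="example_layout"))  # CUDA path (hypothetical kwargs) -> "convert layout"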
0 commit comments
