 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
-from triton_kernels.target_info import is_hip, get_cdna_version
+from triton_kernels.target_info import is_hip, get_cdna_version, is_cuda
 from triton_kernels.tensor import convert_layout
 from triton_kernels.tensor_details.layout import StridedLayout, BlackwellMXScaleLayout, HopperMXScaleLayout, HopperMXValueLayout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
@@ -101,14 +101,15 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     optg = dict()
     opt1 = dict()
     opt2 = dict()
-    if w_dtype == "mx4" and not is_hip():
+    if w_dtype == "mx4":
         value_layout = StridedLayout
         scale_layout = StridedLayout
-        if torch.cuda.get_device_capability()[0] == 9:
-            value_layout = HopperMXValueLayout
-            scale_layout = HopperMXScaleLayout
-        if torch.cuda.get_device_capability()[0] == 10:
-            scale_layout = BlackwellMXScaleLayout
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] == 9:
+                value_layout = HopperMXValueLayout
+                scale_layout = HopperMXScaleLayout
+            if torch.cuda.get_device_capability()[0] == 10:
+                scale_layout = BlackwellMXScaleLayout
         opt1 = {"value_layout": value_layout, "scale_layout": scale_layout}
         opt2 = deepcopy(opt1)
     wg, wg_flex, wg_scale = quantize(wg, "bf16", dev, **optg)
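For reference, the layout selection this patch arrives at can be read as the standalone sketch below. It only restates the logic visible in the hunk above: mx4 quantization now takes the layout path on both CUDA and HIP, defaults to strided layouts, and picks the Hopper/Blackwell MX layouts only when running on a CUDA device of the matching compute capability. The helper name pick_mx4_layouts is hypothetical; is_cuda and the layout classes are the ones imported in the diff.

import torch

from triton_kernels.target_info import is_cuda
from triton_kernels.tensor_details.layout import (
    StridedLayout,
    BlackwellMXScaleLayout,
    HopperMXScaleLayout,
    HopperMXValueLayout,
)


def pick_mx4_layouts():
    # Hypothetical helper mirroring the patched bench_mlp logic.
    # Default: plain strided layouts (used on HIP and on CUDA archs
    # without a specialized MX layout).
    value_layout = StridedLayout
    scale_layout = StridedLayout
    if is_cuda():
        major = torch.cuda.get_device_capability()[0]
        if major == 9:  # Hopper: specialized value and scale layouts
            value_layout = HopperMXValueLayout
            scale_layout = HopperMXScaleLayout
        if major == 10:  # Blackwell: specialized scale layout only
            scale_layout = BlackwellMXScaleLayout
    return {"value_layout": value_layout, "scale_layout": scale_layout}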