|
27 | 27 | import torch.nn.functional as F |
28 | 28 |
|
29 | 29 | # Local |
30 | | -from fms_mo.custom_ext_kernels.triton_kernels import ( |
31 | | - tl_matmul_chunk_truncate as tl_matmul, |
32 | | -) |
33 | 30 | from fms_mo.custom_ext_kernels.utils import pack_vectorized |
34 | 31 | from fms_mo.quant.quantizers import ( |
35 | 32 | HardPrune, |
|
39 | 36 | get_weight_quantizer, |
40 | 37 | mask_fc_kij, |
41 | 38 | ) |
| 39 | +from fms_mo.utils.import_utils import available_packages |
| 40 | + |
| 41 | +if available_packages["triton"]: |
| 42 | + # Local |
| 43 | + from fms_mo.custom_ext_kernels.triton_kernels import ( |
| 44 | + tl_matmul_chunk_truncate as tl_matmul, |
| 45 | + ) |
42 | 46 |
|
43 | 47 | logger = logging.getLogger(__name__) |
44 | 48 |
|
@@ -879,7 +883,9 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs): |
879 | 883 | qlinear_iW.nbits_w = 8 |
880 | 884 | qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float) |
881 | 885 | qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True) |
882 | | - qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton") |
| 886 | + qlinear_iW.use_int_kernel = kwargs.get( |
| 887 | + "use_int_kernel", "triton" if available_packages["triton"] else False |
| 888 | + ) |
883 | 889 | qlinear_iW.weight = nn.Parameter( |
884 | 890 | nnlin_iW.weight.to(torch.int8), requires_grad=False |
885 | 891 | ) |
@@ -1119,15 +1125,15 @@ def set_matmul_op(self): |
1119 | 1125 | imatmul_ops_reg, |
1120 | 1126 | ) |
1121 | 1127 |
|
1122 | | - if self.use_int_kernel == "triton": |
| 1128 | + if self.use_int_kernel == "triton" and available_packages["triton"]: |
1123 | 1129 | # will use real imatmul written in triton |
1124 | 1130 | imm_func = partial( |
1125 | 1131 | tl_matmul, |
1126 | 1132 | chunk_trun_bits=self.truncate_lsb, |
1127 | 1133 | chunk_size=self.chunk_size, |
1128 | 1134 | ) |
1129 | 1135 |
|
1130 | | - elif self.use_int_kernel == "cutlass": |
| 1136 | + elif self.use_int_kernel == "cutlass" and available_packages["cutlass"]: |
1131 | 1137 | # will use real imatmul written in cutlass |
1132 | 1138 | cutlass_ops_load_and_reg() |
1133 | 1139 | # Third Party |
|
0 commit comments