Skip to content

Commit 39dbbae

Browse files
make triton optional if not available on system
Signed-off-by: cliu-us <[email protected]>
1 parent 2f0f780 commit 39dbbae

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

fms_mo/modules/linear.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727
import torch.nn.functional as F
2828

2929
# Local
30-
from fms_mo.custom_ext_kernels.triton_kernels import (
31-
tl_matmul_chunk_truncate as tl_matmul,
32-
)
3330
from fms_mo.custom_ext_kernels.utils import pack_vectorized
3431
from fms_mo.quant.quantizers import (
3532
HardPrune,
@@ -39,6 +36,13 @@
3936
get_weight_quantizer,
4037
mask_fc_kij,
4138
)
39+
from fms_mo.utils.import_utils import available_packages
40+
41+
if available_packages["triton"]:
42+
# Local
43+
from fms_mo.custom_ext_kernels.triton_kernels import (
44+
tl_matmul_chunk_truncate as tl_matmul,
45+
)
4246

4347
logger = logging.getLogger(__name__)
4448

@@ -879,7 +883,9 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):
879883
qlinear_iW.nbits_w = 8
880884
qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float)
881885
qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
882-
qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton")
886+
qlinear_iW.use_int_kernel = kwargs.get(
887+
"use_int_kernel", "triton" if available_packages["triton"] else False
888+
)
883889
qlinear_iW.weight = nn.Parameter(
884890
nnlin_iW.weight.to(torch.int8), requires_grad=False
885891
)
@@ -1127,7 +1133,7 @@ def set_matmul_op(self):
11271133
chunk_size=self.chunk_size,
11281134
)
11291135

1130-
elif self.use_int_kernel == "cutlass":
1136+
elif self.use_int_kernel == "cutlass" and available_packages["cutlass"]:
11311137
# will use real imatmul written in cutlass
11321138
cutlass_ops_load_and_reg()
11331139
# Third Party

fms_mo/utils/import_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"graphviz",
3030
"pygraphviz",
3131
"fms",
32+
"triton",
3233
]
3334

3435
available_packages = {}

tests/triton_kernels/test_triton_mm.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@
1818
import torch
1919

2020
# Local
21-
from fms_mo.custom_ext_kernels.triton_kernels import (
22-
tl_matmul_chunk_truncate as tl_matmul,
23-
)
2421
from fms_mo.modules.linear import LinearFPxAcc
22+
from fms_mo.utils.import_utils import available_packages
23+
24+
if available_packages["triton"]:
25+
# Local
26+
from fms_mo.custom_ext_kernels.triton_kernels import (
27+
tl_matmul_chunk_truncate as tl_matmul,
28+
)
2529

2630

2731
@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])

0 commit comments

Comments
 (0)