Skip to content

Commit b5f84d9

Browse files
Merge pull request #78 from chichun-charlie-liu/main
fix: make triton optional for systems without GPUs
2 parents 395b7f9 + 3615939 commit b5f84d9

File tree

6 files changed

+29
-12
lines changed

6 files changed

+29
-12
lines changed

fms_mo/modules/linear.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727
import torch.nn.functional as F
2828

2929
# Local
30-
from fms_mo.custom_ext_kernels.triton_kernels import (
31-
tl_matmul_chunk_truncate as tl_matmul,
32-
)
3330
from fms_mo.custom_ext_kernels.utils import pack_vectorized
3431
from fms_mo.quant.quantizers import (
3532
HardPrune,
@@ -39,6 +36,13 @@
3936
get_weight_quantizer,
4037
mask_fc_kij,
4138
)
39+
from fms_mo.utils.import_utils import available_packages
40+
41+
if available_packages["triton"]:
42+
# Local
43+
from fms_mo.custom_ext_kernels.triton_kernels import (
44+
tl_matmul_chunk_truncate as tl_matmul,
45+
)
4246

4347
logger = logging.getLogger(__name__)
4448

@@ -879,7 +883,9 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):
879883
qlinear_iW.nbits_w = 8
880884
qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float)
881885
qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
882-
qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton")
886+
qlinear_iW.use_int_kernel = kwargs.get(
887+
"use_int_kernel", "triton" if available_packages["triton"] else False
888+
)
883889
qlinear_iW.weight = nn.Parameter(
884890
nnlin_iW.weight.to(torch.int8), requires_grad=False
885891
)
@@ -1119,15 +1125,15 @@ def set_matmul_op(self):
11191125
imatmul_ops_reg,
11201126
)
11211127

1122-
if self.use_int_kernel == "triton":
1128+
if self.use_int_kernel == "triton" and available_packages["triton"]:
11231129
# will use real imatmul written in triton
11241130
imm_func = partial(
11251131
tl_matmul,
11261132
chunk_trun_bits=self.truncate_lsb,
11271133
chunk_size=self.chunk_size,
11281134
)
11291135

1130-
elif self.use_int_kernel == "cutlass":
1136+
elif self.use_int_kernel == "cutlass" and available_packages["cutlass"]:
11311137
# will use real imatmul written in cutlass
11321138
cutlass_ops_load_and_reg()
11331139
# Third Party

fms_mo/run_quant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def quantize(
9292
"auto_gptq module not found. For more instructions on installing the appropriate "
9393
"package, see https://github.com/AutoGPTQ/AutoGPTQ?tab=readme-ov-file#installation"
9494
)
95+
gptq_args.use_triton = gptq_args.use_triton and available_packages["triton"]
9596
run_gptq(model_args, data_args, opt_args, gptq_args)
9697
elif opt_args.quant_method == "fp8":
9798
if not available_packages["llmcompressor"]:

fms_mo/utils/import_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"graphviz",
3030
"pygraphviz",
3131
"fms",
32+
"triton",
3233
]
3334

3435
available_packages = {}

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ dependencies = [
3737
"huggingface_hub",
3838
"pandas",
3939
"safetensors",
40-
"ibm-fms>=0.0.8"
40+
"ibm-fms>=0.0.8",
41+
"pkginfo>1.10"
4142
]
4243

4344
[project.optional-dependencies]

tests/models/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ def input_bert():
10791079
torch.FloatTensor: BERT sample input
10801080
"""
10811081
text = "Replace me by any text you'd like."
1082-
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
1082+
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
10831083
return tokenizer(text, return_tensors="pt")
10841084

10851085

@@ -1091,4 +1091,4 @@ def model_bert():
10911091
Returns:
10921092
transformers.models.bert.modeling_bert.BertModel: BERT model
10931093
"""
1094-
return BertModel.from_pretrained("bert-base-uncased", torchscript=True)
1094+
return BertModel.from_pretrained("google-bert/bert-base-uncased", torchscript=True)

tests/triton_kernels/test_triton_mm.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,18 @@
1818
import torch
1919

2020
# Local
21-
from fms_mo.custom_ext_kernels.triton_kernels import (
22-
tl_matmul_chunk_truncate as tl_matmul,
23-
)
2421
from fms_mo.modules.linear import LinearFPxAcc
22+
from fms_mo.utils.import_utils import available_packages
23+
24+
if available_packages["triton"]:
25+
# Local
26+
from fms_mo.custom_ext_kernels.triton_kernels import (
27+
tl_matmul_chunk_truncate as tl_matmul,
28+
)
29+
else:
30+
raise ImportError(
31+
"triton python package is not avaialble, please check your installation."
32+
)
2533

2634

2735
@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])

0 commit comments

Comments
 (0)