We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 81fe1bd commit c7cd06c — Copy full SHA for c7cd06c
fms_mo/custom_ext_kernels/utils.py
@@ -873,6 +873,7 @@ def lower_qmodel_triton(
873
clamp_acc_to_dl16=False,
874
num_lsb_to_truncate=0,
875
chunk_size=32,
876
+ layer_to_exclude=[],
877
):
878
"""
879
Exemplar GPU lowering function using triton. Only swap Linear/Qlinear in transformers.
@@ -916,7 +917,7 @@ def lower_qmodel_triton(
916
917
)
918
919
for name, m in model.named_modules():
- if not isinstance(m, (QLinear, torch.nn.Linear)):
920
+ if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude:
921
continue
922
parent_name, module_name = _parent_name(name)
923
parent_mod = model.get_submodule(parent_name)
0 commit comments