Skip to content

Commit c7cd06c

Browse files
lower_qmodel_triton() can skip layers if needed
Signed-off-by: cliu-us <[email protected]>
1 parent 81fe1bd commit c7cd06c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

fms_mo/custom_ext_kernels/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,7 @@ def lower_qmodel_triton(
873873
clamp_acc_to_dl16=False,
874874
num_lsb_to_truncate=0,
875875
chunk_size=32,
876+
layer_to_exclude=[],
876877
):
877878
"""
878879
Exemplar GPU lowering function using triton. Only swap Linear/Qlinear in transformers.
@@ -916,7 +917,7 @@ def lower_qmodel_triton(
916917
)
917918

918919
for name, m in model.named_modules():
919-
if not isinstance(m, (QLinear, torch.nn.Linear)):
920+
if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude:
920921
continue
921922
parent_name, module_name = _parent_name(name)
922923
parent_mod = model.get_submodule(parent_name)

0 commit comments

Comments
 (0)