
Commit 9cf8f65

fix all models
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent: de15a0f

4 files changed: +53 −13


plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py

Lines changed: 4 additions & 4 deletions
@@ -29,7 +29,7 @@
 
 # consider rewriting register_foak_model_patch_rules into something
 # like this also
-def register_foak_model_patch_rules2(
+def register_foak_model_patch_rules(
     base_type: str,
     filter_endswith: Set[str] = None,
     config: PretrainedConfig = None,
@@ -52,8 +52,8 @@ def register_foak_model_patch_rules2(
     # create model specific rules
     rules = [
         *gpt_bigcode.get_mp_rules(base_type),
-        *granite.get_mp_rules(base_type),
-        *llama.get_mp_rules(base_type),
+        *granite.get_mp_rules(base_type, config),
+        *llama.get_mp_rules(base_type, config),
         *mistral.get_mp_rules(base_type, config),
         *mixtral.get_mp_rules(base_type),
     ]
@@ -166,7 +166,7 @@ def augmentation(
 
         # wrapper function to register foak patches
         # - the base layer setting below will be ignored in non quantized-lora settings
-        register_foak_model_patch_rules2(
+        register_foak_model_patch_rules(
            base_type=self.configurations["base_layer"],
            filter_endswith=terms,
            config=model.config,
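
For orientation, here is a minimal, self-contained Python sketch of the pattern this hunk applies: the renamed register_foak_model_patch_rules threads the model's config down to each model module's get_mp_rules, and each get_mp_rules defaults to config=None so call sites that never pass a config keep working. FakeConfig and the rule strings are illustrative stand-ins, not the plugin's real ModelPatcherRule objects or the transformers PretrainedConfig.

from typing import List, Optional


class FakeConfig:  # stand-in for transformers.PretrainedConfig
    def __init__(self, hidden_act: str = "silu"):
        self.hidden_act = hidden_act


def granite_get_mp_rules(base_type: str, config: Optional[FakeConfig] = None) -> List[str]:
    # config defaults to None, so older call sites remain valid
    act = config.hidden_act if config else "unknown"
    return [f"granite/{base_type}/qkv (act={act})"]


def llama_get_mp_rules(base_type: str, config: Optional[FakeConfig] = None) -> List[str]:
    act = config.hidden_act if config else "unknown"
    return [f"llama/{base_type}/qkv (act={act})"]


def register_foak_model_patch_rules(base_type: str, config: Optional[FakeConfig] = None) -> List[str]:
    # single entry point: the same config is fanned out to every model module
    return [
        *granite_get_mp_rules(base_type, config),
        *llama_get_mp_rules(base_type, config),
    ]


print(register_foak_model_patch_rules("auto_gptq", FakeConfig("silu")))
# ['granite/auto_gptq/qkv (act=silu)', 'llama/auto_gptq/qkv (act=silu)']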

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py

Lines changed: 23 additions & 3 deletions
@@ -14,6 +14,7 @@
 
 # Standard
 from functools import partial
+import warnings
 
 # Third Party
 from fms_acceleration.model_patcher import (
@@ -22,15 +23,23 @@
     combine_functions,
     combine_triggers,
 )
+from transformers import PretrainedConfig
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+from .utils import (
+    KEY_MLP,
+    KEY_O,
+    KEY_QKV,
+    build_lora_fused_ops,
+    filter_mp_rules,
+    trigger_fused_ops,
+)
 
 
-def get_mp_rules(base_type: str):
+def get_mp_rules(base_type: str, config: PretrainedConfig = None):
     """
     Function to access all patch rules in this module.
     If it is a forward_builder rule with `base_type` in
@@ -47,7 +56,7 @@ def get_mp_rules(base_type: str):
     except ImportError:
         return []
 
-    return [
+    rules = [
         # TODO: have a generic version of this rule
         # - do regex on RMSNorm class name
         # - check on the tensors required for fast_rms_layernorm
@@ -133,3 +142,14 @@ def get_mp_rules(base_type: str):
             ),
         ),
     ]
+
+    # perform model specific filtering
+    if config and config.hidden_act != "silu":
+        warnings.warn(
+            f"Granite activation is {config.hidden_act}, "
+            "thus disabling LoRA fused-op for MLP, since only SwiGLU "
+            "is supported. This only affects quantized-peft."
+        )
+        rules = filter_mp_rules(rules, {"mlp"}, drop=True)
+
+    return rules
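
The same activation gate is added to llama.py below. As a runnable illustration of what this new tail of get_mp_rules does, here is a sketch with an assumed filter_mp_rules that drops rules whose names contain any of the given terms; the real helper lives in .utils and its matching logic may differ.

import warnings
from typing import List, Set


def filter_mp_rules(rules: List[str], terms: Set[str], drop: bool = False) -> List[str]:
    # assumed behavior: drop=True removes rules matching any term,
    # drop=False keeps only the matching rules
    def matches(rule: str) -> bool:
        return any(term in rule for term in terms)
    return [r for r in rules if matches(r) != drop]


def apply_activation_gate(rules: List[str], hidden_act: str) -> List[str]:
    # mirrors the tail of get_mp_rules: warn and drop the fused MLP rule
    # for models whose MLP is not SwiGLU
    if hidden_act != "silu":
        warnings.warn(
            f"activation is {hidden_act}, thus disabling LoRA fused-op for MLP, "
            "since only SwiGLU is supported. This only affects quantized-peft."
        )
        rules = filter_mp_rules(rules, {"mlp"}, drop=True)
    return rules


print(apply_activation_gate(["granite-rms", "granite-qkv", "granite-mlp"], "gelu"))
# ['granite-rms', 'granite-qkv']  (plus a UserWarning)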

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py

Lines changed: 23 additions & 3 deletions
@@ -14,6 +14,7 @@
 
 # Standard
 from functools import partial
+import warnings
 
 # Third Party
 from fms_acceleration.model_patcher import (
@@ -22,6 +23,7 @@
     combine_functions,
     combine_triggers,
 )
+from transformers import PretrainedConfig
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaMLP,
@@ -32,17 +34,24 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+from .utils import (
+    KEY_MLP,
+    KEY_O,
+    KEY_QKV,
+    build_lora_fused_ops,
+    filter_mp_rules,
+    trigger_fused_ops,
+)
 
 
-def get_mp_rules(base_type: str):
+def get_mp_rules(base_type: str, config: PretrainedConfig = None):
     """
     Function to access all patch rules in this module.
     If it is a forward_builder rule with `base_type` in
     its forward builder argument, wrap the forward_builder
     function as a partial function with the base_type argument
     """
-    return [
+    rules = [
         # TODO: have a generic version of this rule
         # - do regex on RMSNorm class name
         # - check on the tensors required for fast_rms_layernorm
@@ -128,3 +137,14 @@ def get_mp_rules(base_type: str):
             ),
         ),
     ]
+
+    # perform model specific filtering
+    if config and config.hidden_act != "silu":
+        warnings.warn(
+            f"Llama activation is {config.hidden_act}, "
+            "thus disabling LoRA fused-op for MLP, since only SwiGLU "
+            "is supported. This only affects quantized-peft."
+        )
+        rules = filter_mp_rules(rules, {"mlp"}, drop=True)
+
+    return rules

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py

Lines changed: 3 additions & 3 deletions
@@ -44,7 +44,7 @@
 )
 
 
-def get_mp_rules(base_type: str, config: PretrainedConfig):
+def get_mp_rules(base_type: str, config: PretrainedConfig = None):
     """
     Function to access all patch rules in this module.
     If it is a forward_builder rule with `base_type` in
@@ -130,9 +130,9 @@ def get_mp_rules(base_type: str, config: PretrainedConfig):
     ]
 
     # perform model specific filtering
-    if config.hidden_act != "silu":
+    if config and config.hidden_act != "silu":
         warnings.warn(
-            f"Mixtral activation is {config.hdiden_act}, "
+            f"Mistral activation is {config.hidden_act}, "
             "thus disabling LoRA fused-op for MLP, since only SwiGLU "
             "is supported. This only affects quantized-peft."
         )
