
Commit e7e7c3d

add mistral and granite model patch
Signed-off-by: Anh Uong <[email protected]>
1 parent d5a9589 commit e7e7c3d

File tree: 3 files changed, +15 −3 lines changed

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py

Lines changed: 7 additions & 0 deletions
@@ -27,6 +27,7 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
 from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


@@ -40,6 +41,7 @@ def get_mp_rules(base_type: str):
     try:
         # Third Party
         from transformers.models.granite.modeling_granite import (  # pylint: disable=import-outside-toplevel
+            GraniteForCausalLM,
             GraniteAttention,
             GraniteMLP,
             GraniteRMSNorm,
@@ -120,6 +122,11 @@ def get_mp_rules(base_type: str):
                 "transformers.models.granite.modeling_granite",
             ),
         ),
+        ModelPatcherRule(
+            rule_id="granite-fused-lce",
+            trigger=ModelPatcherTrigger(check=GraniteForCausalLM),
+            forward=lce_forward,
+        ),
         # TODO: have a generic version of this rule
         # - get the module name
         # - check if "apply_rotary_pos_emb" exists

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py

Lines changed: 1 addition & 2 deletions
@@ -33,9 +33,8 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
-
 from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
+from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


 def get_mp_rules(base_type: str):
     """

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py

Lines changed: 7 additions & 1 deletion
@@ -23,6 +23,7 @@
     combine_triggers,
 )
 from transformers.models.mistral.modeling_mistral import (
+    MistralForCausalLM,
     MistralAttention,
     MistralMLP,
     MistralRMSNorm,
@@ -32,9 +33,9 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
 from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

-
 def get_mp_rules(base_type):
     """
     Function to access all patch rules in this module.
@@ -110,6 +111,11 @@ def get_mp_rules(base_type):
                 "transformers.models.mistral.modeling_mistral",
             ),
         ),
+        ModelPatcherRule(
+            rule_id="mistral-fused-lce",
+            trigger=ModelPatcherTrigger(check=MistralForCausalLM),
+            forward=lce_forward,
+        ),
         ModelPatcherRule(
             rule_id="mistral-rope",
             import_and_maybe_reload=(
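The `mistral-fused-lce` rule mirrors the granite one: the same Liger `lce_forward`, triggered on `MistralForCausalLM`. A caller would pick these rules up through each module's `get_mp_rules`; a hedged sketch of that collection step follows, where the `base_type` value and the `rule_id` attribute access are assumptions, not confirmed by this commit:

```python
# Hedged sketch: gather the patch rules exposed by the model modules touched
# in this commit. Module paths follow the file tree above; the base_type
# argument value is illustrative only.
from fms_acceleration_foak.models import granite, llama, mistral

rules = []
for mod in (granite, llama, mistral):
    rules.extend(mod.get_mp_rules(base_type="auto_gptq"))

# Assumes each rule exposes its rule_id; expect granite-fused-lce and
# mistral-fused-lce among the collected rules.
print([r.rule_id for r in rules])
```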
