
Commit bb6e04e

add cross ent fix for llama, mistral, mixtral
Signed-off-by: Anh Uong <[email protected]>
1 parent: 2769736 · commit: bb6e04e

File tree: 3 files changed, +64 -28 lines changed

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py

Lines changed: 21 additions & 9 deletions
@@ -33,7 +33,7 @@
 
 # Local
 from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
-from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss, replace_custom_loss_when_triggered
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
 from ..utils import filter_mp_rules
@@ -44,6 +44,7 @@
     build_lora_fused_ops,
     get_hidden_activation_fn_key,
     trigger_fused_ops,
+    get_transformers_version,
 )
 
 
@@ -122,14 +123,25 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
             trigger=ModelPatcherTrigger(check=LlamaForCausalLM),
             forward=lce_forward,
         ),
-        ModelPatcherRule(
-            rule_id="llama-cross-ent",
-            import_and_maybe_reload=(
-                "torch.nn.CrossEntropyLoss",
-                FastCrossEntropyLoss,
-                "transformers.models.llama.modeling_llama",
-            ),
-        ),
+        *[
+            ModelPatcherRule(
+                rule_id="llama-custom-loss",
+                trigger=ModelPatcherTrigger(
+                    check=replace_custom_loss_when_triggered(
+                        LlamaForCausalLM, custom_loss_type="llama-custom-loss"
+                    )
+                ),
+            )
+            if get_transformers_version() >= "4.46" else
+            ModelPatcherRule(
+                rule_id="llama-cross-ent",
+                import_and_maybe_reload=(
+                    "torch.nn.CrossEntropyLoss",
+                    FastCrossEntropyLoss,
+                    "transformers.models.llama.modeling_llama",
+                ),
+            )
+        ],
         # TODO: have a generic version of this rule
         # - get the module name
         # - check if "apply_rotary_pos_emb" exists
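All three files make the same change: the unconditional cross-ent rule, which swapped torch.nn.CrossEntropyLoss for FastCrossEntropyLoss via import_and_maybe_reload, is now used only on transformers versions below 4.46; on 4.46 and newer a trigger-based rule built from replace_custom_loss_when_triggered is registered instead. The 4.46 cutoff lines up with the transformers release that moved causal-LM loss computation out of the model forward into a pluggable loss function, which is presumably why reloading CrossEntropyLoss inside the modeling module no longer takes effect there. The sketch below shows only the list-splat pattern that contributes exactly one of the two rules; Rule and transformers_version() are illustrative stand-ins, not the FOAK ModelPatcherRule or the get_transformers_version helper.

# Illustrative sketch only: Rule and transformers_version() stand in for
# ModelPatcherRule and get_transformers_version; the point is the
# *[A if cond else B] splat, which adds exactly one rule to the list.
from dataclasses import dataclass

try:
    import transformers
    _TRANSFORMERS_VERSION = transformers.__version__
except ImportError:
    _TRANSFORMERS_VERSION = "4.47.0"  # fallback so the sketch runs standalone


@dataclass
class Rule:
    rule_id: str


def transformers_version() -> str:
    return _TRANSFORMERS_VERSION


rules = [
    Rule("llama-fused-lce"),
    # exactly one of the two rules ends up in the list, chosen at import time
    *[
        Rule("llama-custom-loss")
        if transformers_version() >= "4.46"
        else Rule("llama-cross-ent")
    ],
]
print([r.rule_id for r in rules])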

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py

Lines changed: 21 additions & 9 deletions
@@ -33,7 +33,7 @@
 
 # Local
 from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
-from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss, replace_custom_loss_when_triggered
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
 from ..utils import filter_mp_rules
@@ -44,6 +44,7 @@
     build_lora_fused_ops,
     get_hidden_activation_fn_key,
     trigger_fused_ops,
+    get_transformers_version,
 )
 
 
@@ -114,14 +115,25 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
                 base_type=base_type,
             ),
         ),
-        ModelPatcherRule(
-            rule_id="mistral-cross-ent",
-            import_and_maybe_reload=(
-                "torch.nn.CrossEntropyLoss",
-                FastCrossEntropyLoss,
-                "transformers.models.mistral.modeling_mistral",
-            ),
-        ),
+        *[
+            ModelPatcherRule(
+                rule_id="mistral-custom-loss",
+                trigger=ModelPatcherTrigger(
+                    check=replace_custom_loss_when_triggered(
+                        MistralForCausalLM, custom_loss_type="mistral-custom-loss"
+                    )
+                ),
+            )
+            if get_transformers_version() >= "4.46" else
+            ModelPatcherRule(
+                rule_id="mistral-cross-ent",
+                import_and_maybe_reload=(
+                    "torch.nn.CrossEntropyLoss",
+                    FastCrossEntropyLoss,
+                    "transformers.models.mistral.modeling_mistral",
+                ),
+            )
+        ],
         ModelPatcherRule(
             rule_id="mistral-fused-lce",
             trigger=ModelPatcherTrigger(check=MistralForCausalLM),

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py

Lines changed: 22 additions & 10 deletions
@@ -25,13 +25,14 @@
 from transformers.models.mixtral.modeling_mixtral import (
     MixtralAttention,
     MixtralRMSNorm,
+    MixtralForCausalLM,
 )
 
 # Local
-from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss, replace_custom_loss_when_triggered
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops, get_transformers_version
 
 
 def get_mp_rules(base_type):
@@ -85,14 +86,25 @@ def get_mp_rules(base_type):
                 logic="APPEND",
             ),
         ),
-        ModelPatcherRule(
-            rule_id="mixtral-cross-ent",
-            import_and_maybe_reload=(
-                "torch.nn.CrossEntropyLoss",
-                FastCrossEntropyLoss,
-                "transformers.models.mixtral.modeling_mixtral",
-            ),
-        ),
+        *[
+            ModelPatcherRule(
+                rule_id="mixtral-custom-loss",
+                trigger=ModelPatcherTrigger(
+                    check=replace_custom_loss_when_triggered(
+                        MixtralForCausalLM, custom_loss_type="mixtral-custom-loss"
+                    )
+                ),
+            )
+            if get_transformers_version() >= "4.46" else
+            ModelPatcherRule(
+                rule_id="mixtral-cross-ent",
+                import_and_maybe_reload=(
+                    "torch.nn.CrossEntropyLoss",
+                    FastCrossEntropyLoss,
+                    "transformers.models.mixtral.modeling_mixtral",
+                ),
+            )
+        ],
         ModelPatcherRule(
             rule_id="mixtral-rope",
             import_and_maybe_reload=(

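A small aside on the gate itself: each file compares get_transformers_version() >= "4.46". If that helper returns the raw transformers version string, the comparison is lexicographic, which happens to work for this cutoff but does not order versions numerically in general. Purely as an illustration (this is not the helper imported from models/utils.py in the diff), a numeric gate could be written with packaging:

# Illustrative only, not the get_transformers_version helper from this repo.
from packaging.version import Version

try:
    import transformers
    _installed = transformers.__version__
except ImportError:
    _installed = "4.47.0"  # fallback so the sketch runs without transformers


def transformers_at_least(minimum: str) -> bool:
    # Version compares release tuples numerically, so "4.5" < "4.46" here,
    # whereas plain string comparison would put "4.5" after "4.46".
    return Version(_installed) >= Version(minimum)


print(transformers_at_least("4.46"))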