Skip to content

Commit 2769736

Browse files
committed
trigger-only pattern for custom loss
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
1 parent 8787ca1 commit 2769736

File tree

5 files changed

+88
-16
lines changed

5 files changed

+88
-16
lines changed

plugins/framework/src/fms_acceleration/model_patcher.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,10 @@ def __post_init__(self):
184184
self.import_and_maybe_reload is not None,
185185
]
186186
)
187-
!= 1
187+
> 1
188188
):
189189
raise ValueError(
190-
f"Rule '{self.rule_id}' must only have only one of forward, "
190+
f"Rule '{self.rule_id}' must only have at most one of forward, "
191191
"foward builder, or import_and_maybe_reload, specified."
192192
)
193193

@@ -425,7 +425,7 @@ def _patch_forwards(
425425
# otherwise triggered
426426
if rule.forward is not None:
427427
forward = rule.forward
428-
else:
428+
elif rule.forward_builder is not None:
429429
fba = {}
430430
if rule.forward_builder_args is not None:
431431
fba = {
@@ -434,6 +434,9 @@ def _patch_forwards(
434434
if rule.forward_builder_args
435435
}
436436
forward = rule.forward_builder(mod, **fba)
437+
else:
438+
# trigger-only case
439+
forward = None
437440

438441
if isinstance(forward, list):
439442
# this will be list of tuples case
@@ -468,7 +471,8 @@ def _patch_forwards(
468471
continue
469472

470473
# otherwise
471-
mod.forward = MethodType(forward, mod)
474+
if forward is not None:
475+
mod.forward = MethodType(forward, mod)
472476
ModelPatcher.history.append(
473477
ModelPatcherHistory(
474478
instance=mod_id,

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def register_foak_model_patch_rules(
7373
FILTER_MAP = {
7474
"fused_lora": {"qkvo", "mlp"},
7575
"fast_loss": {
76-
True: "cross-ent",
76+
True: {"cross-ent", "custom-loss"},
7777
"fused_ce_liger": "fused-lce",
7878
},
7979
"fast_rms_layernorm": "rms",

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import triton.language as tl
1717
import torch
1818
from .utils import calculate_settings, MAX_FUSED_SIZE
19+
from typing import Type
1920

2021

2122
@triton.jit
@@ -290,3 +291,55 @@ def forward(self, input, target):
290291
)
291292
n_items = torch.count_nonzero(target != -100)
292293
return loss.sum() / n_items
294+
295+
296+
# added by flim@sg.ibm.com
297+
298+
# adapted from transformers.loss.loss_utils.ForCausalLMLoss
299+
def FastForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    """Drop-in replacement for HF ``ForCausalLMLoss`` backed by the fused
    Triton ``Fast_CrossEntropyLoss`` kernel.

    Args:
        logits: model output logits of shape ``(..., seq_len, vocab_size)``.
        labels: target token ids with the same leading shape as ``logits``;
            positions equal to ``ignore_index`` do not contribute to the loss.
        vocab_size: vocabulary size (last dimension of ``logits``).
        num_items_in_batch: when provided, the summed loss is divided by this
            count (HF passes it for correct scaling under gradient
            accumulation); otherwise the mean over non-ignored tokens is used.
        ignore_index: label value to ignore. Only ``-100`` is supported since
            the Triton kernel hardcodes that value.
        **kwargs: ignored; present for signature compatibility with the HF
            loss-function registry.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``ignore_index`` is not ``-100``.
    """
    # Fail fast on unsupported config. A plain `assert` would be stripped
    # under `python -O`, so raise explicitly instead.
    if ignore_index != -100:
        raise ValueError(
            "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
        )

    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    labels = labels.to(logits.device)
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)

    loss = Fast_CrossEntropyLoss.apply(
        shift_logits, shift_labels
    )

    # HF convention: with "sum"-style reduction, normalize by the caller's
    # token count; otherwise average over the non-ignored tokens here.
    # NOTE(review): if every label equals ignore_index the divisor is 0 and
    # the result is nan — same as upstream ForCausalLMLoss behavior.
    if num_items_in_batch is not None:
        n_items = num_items_in_batch
    else:
        n_items = torch.count_nonzero(shift_labels != ignore_index)
    return loss.sum() / n_items
325+
326+
327+
def replace_custom_loss_when_triggered(
    module_cls: Type,
    custom_loss_type: str,
):
    """Build a ModelPatcher trigger check that installs the fast loss.

    The returned callable reports a match for instances of ``module_cls``
    that expose a ``loss_function`` attribute; on a match it registers
    ``FastForCausalLMLoss`` under ``custom_loss_type`` in HF's
    ``LOSS_MAPPING`` and points the module at it via ``mod.loss_type``.
    """

    def _check(mod):
        # guard clauses: only patch a matching module that has a loss_function
        if not isinstance(mod, module_cls):
            return False
        if not hasattr(mod, "loss_function"):
            return False

        # import kept inside the match branch so it only runs when the
        # transformers version actually provides loss_utils
        from transformers.loss.loss_utils import LOSS_MAPPING

        LOSS_MAPPING[custom_loss_type] = FastForCausalLMLoss
        mod.loss_type = custom_loss_type
        return True

    return _check
344+
345+

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
# Local
2929
from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
30-
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
30+
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss, replace_custom_loss_when_triggered
3131
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
3232
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
3333
from ..utils import filter_mp_rules
@@ -38,6 +38,7 @@
3838
build_lora_fused_ops,
3939
get_hidden_activation_fn_key,
4040
trigger_fused_ops,
41+
get_transformers_version,
4142
)
4243

4344

@@ -122,16 +123,25 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
122123
base_type=base_type,
123124
),
124125
),
125-
# TODO: have a generic version of this rule
126-
# - get the module_name and reload on that
127-
ModelPatcherRule(
128-
rule_id="granite-cross-ent",
129-
import_and_maybe_reload=(
130-
"torch.nn.CrossEntropyLoss",
131-
FastCrossEntropyLoss,
132-
"transformers.models.granite.modeling_granite",
133-
),
134-
),
126+
*[
127+
ModelPatcherRule(
128+
rule_id="granite-custom-loss",
129+
trigger=ModelPatcherTrigger(
130+
check=replace_custom_loss_when_triggered(
131+
GraniteForCausalLM, custom_loss_type="granite-custom-loss"
132+
)
133+
),
134+
)
135+
if get_transformers_version() >= "4.46" else
136+
ModelPatcherRule(
137+
rule_id="granite-cross-ent",
138+
import_and_maybe_reload=(
139+
"torch.nn.CrossEntropyLoss",
140+
FastCrossEntropyLoss,
141+
"transformers.models.granite.modeling_granite",
142+
),
143+
)
144+
],
135145
ModelPatcherRule(
136146
rule_id="granite-fused-lce",
137147
trigger=ModelPatcherTrigger(check=GraniteForCausalLM),

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Third Party
77
from fms_acceleration.model_patcher import ModelPatcherTrigger
88
from transformers import PretrainedConfig
9+
from transformers.utils.import_utils import _is_package_available
910
import torch
1011

1112
# Local
@@ -214,3 +215,7 @@ def get_hidden_activation_fn_key(config: PretrainedConfig):
214215
"Unable to determine activation function key for "
215216
f"architecture {config.architectures}."
216217
)
218+
219+
def get_transformers_version():
    """Return the installed ``transformers`` version as a string.

    NOTE(review): callers compare this lexicographically (e.g.
    ``>= "4.46"``), which misorders versions such as "4.100" — consider
    ``packaging.version.parse`` at the comparison sites.
    """
    availability = _is_package_available("transformers", return_version=True)
    return availability[1]

0 commit comments

Comments
 (0)