
Commit 24bdadb

anhuong and fabianlim authored
fix: cross entropy for transformers>4.45 (#123)
* trigger-only pattern for custom loss (Signed-off-by: Yu Chin Fabian Lim <[email protected]>)
* add cross ent fix for llama, mistral, mixtral (Signed-off-by: Anh Uong <[email protected]>)
* fix linting errors (Signed-off-by: Anh Uong <[email protected]>)
* run formatter (Signed-off-by: Anh Uong <[email protected]>)
* fix misspelling and error test (Signed-off-by: Anh Uong <[email protected]>)
* fix import error with later transformers (Signed-off-by: Anh Uong <[email protected]>)
* add benchmarks (Signed-off-by: Anh Uong <[email protected]>)
* fix import order (Signed-off-by: Anh Uong <[email protected]>)
* replace benchmark and requirements (Signed-off-by: Anh Uong <[email protected]>)

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
1 parent 8787ca1 commit 24bdadb

File tree: 14 files changed (+465, -224 lines)


plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py

Lines changed: 105 additions & 2 deletions
@@ -33,11 +33,14 @@
     PreTrainedModel,
 )
 from transformers.modeling_utils import (
+    dtype_byte_size,
     is_local_dist_rank_0,
     no_init_weights,
-    shard_checkpoint,
 )
+from transformers.pytorch_utils import id_tensor_storage
+from transformers.utils import WEIGHTS_NAME
 from transformers.utils.generic import ContextManagers
+from transformers.utils.hub import convert_file_size_to_int
 import accelerate
 import torch
 import torch.nn as nn
@@ -688,7 +691,7 @@ def save_quantized(
             torch.save(model.state_dict(), join(save_dir, model_save_name))
         else:
             # Shard checkpoint
-            shards, index = shard_checkpoint(
+            shards, index = self.shard_checkpoint(
                 state_dict, max_shard_size=max_shard_size, weights_name=model_save_name
             )

@@ -766,6 +769,106 @@ def save_quantized(
         quantize_config.model_file_base_name = model_base_name
         quantize_config.save_pretrained(save_dir)

+    # added by [email protected]
+    # adapted from transformers.modeling_utils.shard_checkpoint
+    # from transformers v4.46, removed in later versions
+    # TODO: split_torch_state_dict_into_shards from huggingface_hub library
+    def shard_checkpoint(
+        self,
+        state_dict: Dict[str, torch.Tensor],
+        max_shard_size: Union[int, str] = "10GB",
+        weights_name: str = WEIGHTS_NAME,
+    ):
+        """
+        Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not
+        exceed a given size.
+
+        The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there
+        is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
+        example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get
+        sharded as [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
+
+        <Tip warning={true}>
+
+        If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint
+        which will have a size greater than `max_shard_size`.
+
+        </Tip>
+
+        Args:
+            state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
+            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+                The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a
+                unit (like `"5MB"`).
+            weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
+                The name of the model save file.
+        """
+        logger.warning(
+            "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
+            "split_torch_state_dict_into_shards from huggingface_hub library"
+        )
+        max_shard_size = convert_file_size_to_int(max_shard_size)
+
+        sharded_state_dicts = [{}]
+        last_block_size = 0
+        total_size = 0
+        storage_id_to_block = {}
+
+        for key, weight in state_dict.items():
+            # when bnb serialization is used the weights in the state dict can be strings
+            # check: https://github.com/huggingface/transformers/pull/24416 for more details
+            if isinstance(weight, str):
+                continue
+            else:
+                storage_id = id_tensor_storage(weight)
+
+            # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
+            if storage_id in storage_id_to_block and weight.device != torch.device(
+                "meta"
+            ):
+                block_id = storage_id_to_block[storage_id]
+                sharded_state_dicts[block_id][key] = weight
+                continue
+
+            weight_size = weight.numel() * dtype_byte_size(weight.dtype)
+            # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
+            # weight in the current shard.
+            if (
+                last_block_size + weight_size > max_shard_size
+                and len(sharded_state_dicts[-1]) > 0
+            ):
+                sharded_state_dicts.append({})
+                last_block_size = 0
+
+            sharded_state_dicts[-1][key] = weight
+            last_block_size += weight_size
+            total_size += weight_size
+            storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1
+
+        # If we only have one shard, we return it
+        if len(sharded_state_dicts) == 1:
+            return {weights_name: sharded_state_dicts[0]}, None
+
+        # Otherwise, let's build the index
+        weight_map = {}
+        shards = {}
+        for idx, shard in enumerate(sharded_state_dicts):
+            shard_file = weights_name.replace(
+                ".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin"
+            )
+            shard_file = shard_file.replace(
+                ".safetensors",
+                f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors",
+            )
+            shards[shard_file] = shard
+            for key in shard.keys():
+                weight_map[key] = shard_file
+
+        # Add the metadata
+        metadata = {"total_size": total_size}
+        index = {"metadata": metadata, "weight_map": weight_map}
+        return shards, index
+
     def save_pretrained(
         self,
         save_dir: str,

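The vendored docstring describes a greedy, key-ordered sharding policy. As a quick illustration (not part of the commit, plain sizes instead of tensors), the grouping it describes can be reproduced like this:

# Illustration only: the greedy grouping policy described in the shard_checkpoint
# docstring above, reduced to plain sizes in GB (no torch tensors involved).
def greedy_shards(sizes, max_shard_size):
    shards, current, current_size = [], [], 0
    for size in sizes:
        # start a new shard only if adding this weight would overflow the limit
        # and the current shard already holds at least one weight
        if current_size + size > max_shard_size and current:
            shards.append(current)
            current, current_size = [], 0
        current.append(size)
        current_size += size
    shards.append(current)
    return shards

# [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] with a 10GB limit -> [6], [6, 2], [6, 2, 2]
assert greedy_shards([6, 6, 2, 6, 2, 2], 10) == [[6], [6, 2], [6, 2, 2]]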
plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/multipack_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
-taken from https://github.com/imoneoi/multipack_sampler with some modifications
+taken from https://github.com/imoneoi/multipack_sampler with some modifications
 taken from https://github.com/instructlab/training/blob/main/src/instructlab/training/multipack_sampler.py
 """

plugins/framework/src/fms_acceleration/model_patcher.py

Lines changed: 9 additions & 5 deletions
@@ -184,11 +184,11 @@ def __post_init__(self):
                     self.import_and_maybe_reload is not None,
                 ]
             )
-            != 1
+            > 1
         ):
             raise ValueError(
-                f"Rule '{self.rule_id}' must only have only one of forward, "
-                "foward builder, or import_and_maybe_reload, specified."
+                f"Rule '{self.rule_id}' must only have at most one of forward, "
+                "forward builder, or import_and_maybe_reload, specified."
             )

         if self.import_and_maybe_reload is not None and self.trigger is not None:
@@ -425,7 +425,7 @@ def _patch_forwards(
             # otherwise triggered
             if rule.forward is not None:
                 forward = rule.forward
-            else:
+            elif rule.forward_builder is not None:
                 fba = {}
                 if rule.forward_builder_args is not None:
                     fba = {
@@ -434,6 +434,9 @@ def _patch_forwards(
                         if rule.forward_builder_args
                     }
                 forward = rule.forward_builder(mod, **fba)
+            else:
+                # trigger-only case
+                forward = None

             if isinstance(forward, list):
                 # this will be list of tuples case
@@ -468,7 +471,8 @@ def _patch_forwards(
                 continue

             # otherwise
-            mod.forward = MethodType(forward, mod)
+            if forward is not None:
+                mod.forward = MethodType(forward, mod)
             ModelPatcher.history.append(
                 ModelPatcherHistory(
                     instance=mod_id,

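The change above allows a rule to carry only a trigger: no forward is resolved, so nothing is bound to the module and the trigger's check does all the work. A standalone sketch of the new dispatch, using simplified stand-ins rather than the actual ModelPatcher classes:

# Simplified stand-in for the patched _patch_forwards dispatch; `rule` here is a
# plain dict instead of a ModelPatcherRule (an assumption made for illustration).
from types import MethodType

def resolve_forward(rule, mod):
    if rule.get("forward") is not None:
        return rule["forward"]
    if rule.get("forward_builder") is not None:
        return rule["forward_builder"](mod)
    # trigger-only case: the trigger check has already had its side effect
    # (e.g. swapping the loss function), so there is no forward to bind
    return None

def apply_rule(rule, mod):
    forward = resolve_forward(rule, mod)
    if forward is not None:
        # bind only when a forward was actually produced
        mod.forward = MethodType(forward, mod)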
plugins/framework/tests/test_model_patcher_helpers.py

Lines changed: 3 additions & 3 deletions
@@ -254,12 +254,12 @@ def test_combine_mp_triggers_produces_correct_output(


 def test_mp_rule_raises_error_when_arguments_incorrectly_configured():
-    "Ensure MP rule is throws appropriate error when wrong argument combinations are passed"
+    "Ensure MP rule throws appropriate error when wrong argument combinations are passed"
     # Test mp rule construction raises with multiple arguments
     with pytest.raises(
         ValueError,
-        match="must only have only one of forward, "
-        "foward builder, or import_and_maybe_reload, specified.",
+        match="must only have at most one of forward, "
+        "forward builder, or import_and_maybe_reload, specified.",
     ):
         ModelPatcherRule(
             rule_id=DUMMY_RULE_ID,

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def register_foak_model_patch_rules(
 FILTER_MAP = {
     "fused_lora": {"qkvo", "mlp"},
     "fast_loss": {
-        True: "cross-ent",
+        True: {"cross-ent", "custom-loss"},
         "fused_ce_liger": "fused-lce",
     },
     "fast_rms_layernorm": "rms",

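For orientation, FILTER_MAP translates plugin settings into terms used to select model-patcher rules, and the change above lets fast_loss=True pick up both the old *-cross-ent rules and the new *-custom-loss rules. A hedged sketch of such a selection (the substring-on-rule_id matching below is an assumption for illustration, not taken from this diff; the real logic lives in filter_mp_rules):

# Hypothetical illustration of how a set of filter terms could select rules by id;
# the actual filter_mp_rules helper in fms_acceleration_foak.utils may differ.
def select_rule_ids(rule_ids, terms):
    if isinstance(terms, str):
        terms = {terms}
    # keep any rule whose id contains one of the filter terms
    return [rid for rid in rule_ids if any(term in rid for term in terms)]

rule_ids = ["llama-custom-loss", "llama-cross-ent", "llama-fused-lce", "llama-rms"]
# fast_loss=True now maps to {"cross-ent", "custom-loss"}, so both loss rules match
print(select_rule_ids(rule_ids, {"cross-ent", "custom-loss"}))
# ['llama-custom-loss', 'llama-cross-ent']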
plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py

Lines changed: 53 additions & 0 deletions
@@ -16,6 +16,7 @@
 import triton.language as tl
 import torch
 from .utils import calculate_settings, MAX_FUSED_SIZE
+from typing import Type


 @triton.jit
@@ -290,3 +291,55 @@ def forward(self, input, target):
         )
         n_items = torch.count_nonzero(target != -100)
         return loss.sum() / n_items
+
+
+
+
+# adapted from transformers.loss.loss_utils.ForCausalLMLoss
+def FastForCausalLMLoss(
+    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
+):
+    # Upcast to float if we need to compute the loss to avoid potential precision issues
+    logits = logits.float()
+    labels = labels.to(logits.device)
+    # Shift so that tokens < n predict n
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+
+    # Flatten the tokens
+    shift_logits = shift_logits.view(-1, vocab_size)
+    shift_labels = shift_labels.view(-1)
+    # Enable model parallelism
+    shift_labels = shift_labels.to(shift_logits.device)
+
+    reduction = "sum" if num_items_in_batch is not None else "mean"
+    assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
+    loss = Fast_CrossEntropyLoss.apply(
+        shift_logits, shift_labels
+    )
+    if reduction == "sum":
+        n_items = num_items_in_batch
+    else:
+        n_items = torch.count_nonzero(shift_labels != -100)
+    return loss.sum() / n_items
+
+
+def replace_custom_loss_when_triggered(
+    module_cls: Type,
+    custom_loss_type: str,
+):
+
+    # this is a special trigger that will perform the replacement
+    def _trigger(mod):
+        if isinstance(mod, module_cls) and hasattr(mod, "loss_function"):
+            # guarded
+            from transformers.loss.loss_utils import LOSS_MAPPING
+            LOSS_MAPPING[custom_loss_type] = FastForCausalLMLoss
+            mod.loss_type = custom_loss_type
+            return True
+
+        return False
+
+    return _trigger

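To make the intent of the trigger-only hook concrete, here is a hedged usage sketch for transformers >= 4.46. The checkpoint name is a placeholder, the direct call to the check is illustrative (in practice the model patcher invokes it while walking modules), and actually computing the loss needs a GPU since Fast_CrossEntropyLoss is a Triton kernel:

# Illustrative only: expected effect of the trigger on a causal LM under
# transformers >= 4.46. The checkpoint below is a placeholder assumption.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

check = replace_custom_loss_when_triggered(
    type(model), custom_loss_type="llama-custom-loss"
)

# The model patcher calls the check on each module; on the top-level module it
# registers FastForCausalLMLoss under the custom loss type and re-points loss_type.
assert check(model) is True
assert model.loss_type == "llama-custom-loss"
# model.loss_function should now resolve to FastForCausalLMLoss via LOSS_MAPPING.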
plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py

Lines changed: 26 additions & 11 deletions
@@ -27,7 +27,10 @@

 # Local
 from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
-from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.cross_entropy_loss import (
+    FastCrossEntropyLoss,
+    replace_custom_loss_when_triggered,
+)
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
 from ..utils import filter_mp_rules
@@ -37,6 +40,7 @@
     KEY_QKV,
     build_lora_fused_ops,
     get_hidden_activation_fn_key,
+    get_transformers_version,
     trigger_fused_ops,
 )

@@ -122,16 +126,27 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
                 base_type=base_type,
             ),
         ),
-        # TODO: have a generic version of this rule
-        # - get the module_name and reload on that
-        ModelPatcherRule(
-            rule_id="granite-cross-ent",
-            import_and_maybe_reload=(
-                "torch.nn.CrossEntropyLoss",
-                FastCrossEntropyLoss,
-                "transformers.models.granite.modeling_granite",
-            ),
-        ),
+        *[
+            (
+                ModelPatcherRule(
+                    rule_id="granite-custom-loss",
+                    trigger=ModelPatcherTrigger(
+                        check=replace_custom_loss_when_triggered(
+                            GraniteForCausalLM, custom_loss_type="granite-custom-loss"
+                        )
+                    ),
+                )
+                if get_transformers_version() >= "4.46"
+                else ModelPatcherRule(
+                    rule_id="granite-cross-ent",
+                    import_and_maybe_reload=(
+                        "torch.nn.CrossEntropyLoss",
+                        FastCrossEntropyLoss,
+                        "transformers.models.granite.modeling_granite",
+                    ),
+                )
+            )
+        ],
         ModelPatcherRule(
             rule_id="granite-fused-lce",
             trigger=ModelPatcherTrigger(check=GraniteForCausalLM),

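The rule list is now assembled conditionally on the installed transformers version: 4.46 and later use the trigger-only custom-loss rule (that is where transformers moved loss computation into loss_function and LOSS_MAPPING), while older releases keep the import-and-reload of torch.nn.CrossEntropyLoss. A rough sketch of the gate with a stand-in version check (get_transformers_version in the FOAK utils may behave differently; importlib.metadata and packaging here are assumptions):

# Stand-in version gate; the real helper is get_transformers_version from the
# FOAK model utils and may return a different type or compare differently.
from importlib.metadata import version
from packaging.version import Version

def use_trigger_only_custom_loss() -> bool:
    # transformers >= 4.46 exposes loss_function / LOSS_MAPPING, so the
    # trigger-only "*-custom-loss" rule applies; older releases fall back to
    # reloading torch.nn.CrossEntropyLoss inside the model's modeling module.
    return Version(version("transformers")) >= Version("4.46")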
plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py

Lines changed: 26 additions & 9 deletions
@@ -33,7 +33,10 @@

 # Local
 from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
-from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.cross_entropy_loss import (
+    FastCrossEntropyLoss,
+    replace_custom_loss_when_triggered,
+)
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
 from ..utils import filter_mp_rules
@@ -43,6 +46,7 @@
     KEY_QKV,
     build_lora_fused_ops,
     get_hidden_activation_fn_key,
+    get_transformers_version,
     trigger_fused_ops,
 )

@@ -122,14 +126,27 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
             trigger=ModelPatcherTrigger(check=LlamaForCausalLM),
             forward=lce_forward,
         ),
-        ModelPatcherRule(
-            rule_id="llama-cross-ent",
-            import_and_maybe_reload=(
-                "torch.nn.CrossEntropyLoss",
-                FastCrossEntropyLoss,
-                "transformers.models.llama.modeling_llama",
-            ),
-        ),
+        *[
+            (
+                ModelPatcherRule(
+                    rule_id="llama-custom-loss",
+                    trigger=ModelPatcherTrigger(
+                        check=replace_custom_loss_when_triggered(
+                            LlamaForCausalLM, custom_loss_type="llama-custom-loss"
+                        )
+                    )
+                )
+                if get_transformers_version() >= "4.46"
+                else ModelPatcherRule(
+                    rule_id="llama-cross-ent",
+                    import_and_maybe_reload=(
+                        "torch.nn.CrossEntropyLoss",
+                        FastCrossEntropyLoss,
+                        "transformers.models.llama.modeling_llama",
+                    ),
+                )
+            )
+        ],
         # TODO: have a generic version of this rule
         # - get the module name
         # - check if "apply_rotary_pos_emb" exists
