@@ -33,11 +33,14 @@
PreTrainedModel,
)
from transformers.modeling_utils import (
dtype_byte_size,
is_local_dist_rank_0,
no_init_weights,
shard_checkpoint,
)
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import WEIGHTS_NAME
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import convert_file_size_to_int
import accelerate
import torch
import torch.nn as nn
@@ -688,7 +691,7 @@ def save_quantized(
torch.save(model.state_dict(), join(save_dir, model_save_name))
else:
# Shard checkpoint
shards, index = shard_checkpoint(
shards, index = self.shard_checkpoint(
state_dict, max_shard_size=max_shard_size, weights_name=model_save_name
)

@@ -766,6 +769,106 @@ def save_quantized(
quantize_config.model_file_base_name = model_base_name
quantize_config.save_pretrained(save_dir)

# added by anh.uong@ibm.com
# adapted from transformers.modeling_utils.shard_checkpoint
# from transformers v4.46, removed in later versions
# TODO: migrate to split_torch_state_dict_into_shards from the huggingface_hub library
def shard_checkpoint(
Comment on lines +772 to +776

Collaborator Author

After transformers v4.46 this method no longer exists in transformers, so I copied it in here to start. The warning message in the original function says to migrate to split_torch_state_dict_into_shards, as noted in the TODO item here; that method is similar but the differences need more investigation - https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py#L302 (a sketch of that migration follows after this method below).

Contributor

ok this is fine for now
self,
state_dict: Dict[str, torch.Tensor],
max_shard_size: Union[int, str] = "10GB",
weights_name: str = WEIGHTS_NAME,
):
"""
Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.

The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
[6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

<Tip warning={true}>

If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
have a size greater than `max_shard_size`.

</Tip>

Args:
state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
(like `"5MB"`).
weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
The name of the model save file.
"""
logger.warning(
"Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
"split_torch_state_dict_into_shards from huggingface_hub library"
)
max_shard_size = convert_file_size_to_int(max_shard_size)

sharded_state_dicts = [{}]
last_block_size = 0
total_size = 0
storage_id_to_block = {}

for key, weight in state_dict.items():
# when bnb serialization is used the weights in the state dict can be strings
# check: https://github.com/huggingface/transformers/pull/24416 for more details
if isinstance(weight, str):
continue
else:
storage_id = id_tensor_storage(weight)

# If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
if storage_id in storage_id_to_block and weight.device != torch.device(
"meta"
):
block_id = storage_id_to_block[storage_id]
sharded_state_dicts[block_id][key] = weight
continue

weight_size = weight.numel() * dtype_byte_size(weight.dtype)
# If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
# weight in the current shard.
if (
last_block_size + weight_size > max_shard_size
and len(sharded_state_dicts[-1]) > 0
):
sharded_state_dicts.append({})
last_block_size = 0

sharded_state_dicts[-1][key] = weight
last_block_size += weight_size
total_size += weight_size
storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None

# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(
".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin"
)
shard_file = shard_file.replace(
".safetensors",
f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors",
)
shards[shard_file] = shard
for key in shard.keys():
weight_map[key] = shard_file

# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index

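For reference on the TODO above: a minimal sketch of what the eventual migration to huggingface_hub's split_torch_state_dict_into_shards could look like, adapted from the huggingface_hub documentation. The helper name `save_sharded` and the safetensors output path are illustrative assumptions, not part of this PR.

```python
# Sketch only (not part of this PR): the migration path noted in the TODO above,
# based on the documented huggingface_hub API. Verify against the installed version.
import json
import os

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file


def save_sharded(state_dict, save_dir: str, max_shard_size="10GB"):
    split = split_torch_state_dict_into_shards(
        state_dict, max_shard_size=max_shard_size
    )
    for filename, tensor_names in split.filename_to_tensors.items():
        # materialize each shard and write it out
        shard = {name: state_dict[name] for name in tensor_names}
        save_file(shard, os.path.join(save_dir, filename))
    if split.is_sharded:
        # same index layout as the {"metadata": ..., "weight_map": ...} dict above
        index = {
            "metadata": split.metadata,
            "weight_map": split.tensor_to_filename,
        }
        with open(os.path.join(save_dir, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f, indent=2)
```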
def save_pretrained(
self,
save_dir: str,
@@ -20,7 +20,7 @@
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
taken from https://github.com/imoneoi/multipack_sampler with some modifications
taken from https://github.com/instructlab/training/blob/main/src/instructlab/training/multipack_sampler.py
"""

14 changes: 9 additions & 5 deletions plugins/framework/src/fms_acceleration/model_patcher.py
@@ -184,11 +184,11 @@ def __post_init__(self):
self.import_and_maybe_reload is not None,
]
)
!= 1
> 1
):
raise ValueError(
f"Rule '{self.rule_id}' must only have only one of forward, "
"foward builder, or import_and_maybe_reload, specified."
f"Rule '{self.rule_id}' must only have at most one of forward, "
"forward builder, or import_and_maybe_reload, specified."
)

if self.import_and_maybe_reload is not None and self.trigger is not None:
Expand Down Expand Up @@ -425,7 +425,7 @@ def _patch_forwards(
# otherwise triggered
if rule.forward is not None:
forward = rule.forward
else:
elif rule.forward_builder is not None:
fba = {}
if rule.forward_builder_args is not None:
fba = {
@@ -434,6 +434,9 @@
if rule.forward_builder_args
}
forward = rule.forward_builder(mod, **fba)
else:
# trigger-only case
forward = None

if isinstance(forward, list):
# this will be list of tuples case
@@ -468,7 +471,8 @@ def _patch_forwards(
continue

# otherwise
mod.forward = MethodType(forward, mod)
if forward is not None:
mod.forward = MethodType(forward, mod)
ModelPatcher.history.append(
ModelPatcherHistory(
instance=mod_id,
6 changes: 3 additions & 3 deletions plugins/framework/tests/test_model_patcher_helpers.py
@@ -254,12 +254,12 @@ def test_combine_mp_triggers_produces_correct_output(


def test_mp_rule_raises_error_when_arguments_incorrectly_configured():
"Ensure MP rule is throws appropriate error when wrong argument combinations are passed"
"Ensure MP rule throws appropriate error when wrong argument combinations are passed"
# Test mp rule construction raises with multiple arguments
with pytest.raises(
ValueError,
match="must only have only one of forward, "
"foward builder, or import_and_maybe_reload, specified.",
match="must only have at most one of forward, "
"forward builder, or import_and_maybe_reload, specified.",
):
ModelPatcherRule(
rule_id=DUMMY_RULE_ID,
@@ -73,7 +73,7 @@ def register_foak_model_patch_rules(
FILTER_MAP = {
"fused_lora": {"qkvo", "mlp"},
"fast_loss": {
True: "cross-ent",
True: {"cross-ent", "custom-loss"},
"fused_ce_liger": "fused-lce",
},
"fast_rms_layernorm": "rms",
@@ -16,6 +16,7 @@
import triton.language as tl
import torch
from .utils import calculate_settings, MAX_FUSED_SIZE
from typing import Type


@triton.jit
@@ -290,3 +291,55 @@ def forward(self, input, target):
)
n_items = torch.count_nonzero(target != -100)
return loss.sum() / n_items


# added by flim@sg.ibm.com

# adapted from transformers.loss.loss_utils.ForCausalLMLoss
def FastForCausalLMLoss(
Collaborator Author

Would we need to create a similar FastForCausalLMLoss for liger kernel cross entropy?

Contributor

Yes, I think we will have a new function for liger cross entropy with the same API; then it's plug and play. But it should be used only if the transformers version is recent enough (a hypothetical sketch follows after replace_custom_loss_when_triggered below).

logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)

reduction = "sum" if num_items_in_batch is not None else "mean"
assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
Collaborator Author

What is the -100 ignore_index? I see that ignore_index is the target value that is ignored and does not contribute to the input gradient, but for CausalLMLoss what is at index -100?

Contributor

-100 is used extensively throughout HF; while they provide some means for the user to change it, almost nobody bothers to change it.

It is the label that is set to -100. For a label with that value, we ignore that token's contribution to the loss (see the short stand-alone example after this function below).

loss = Fast_CrossEntropyLoss.apply(
Collaborator Author

Can you describe the difference between Fast_CrossEntropyLoss and FastCrossEntropyLoss?

Contributor

  • Fast_CrossEntropyLoss is the autograd function; we inherit this from unsloth.
  • FastCrossEntropyLoss is a specialization of torch.nn.CrossEntropyLoss that serves as a convenience, implemented using Fast_CrossEntropyLoss.

shift_logits, shift_labels
)
if reduction == "sum":
n_items = num_items_in_batch
else:
n_items = torch.count_nonzero(shift_labels != -100)
return loss.sum() / n_items

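To make the -100 discussion above concrete, here is a tiny stand-alone illustration (not part of the PR) of how labels set to -100 are excluded both from the loss and from the token count used for normalization:

```python
# Stand-alone illustration of ignore_index=-100 semantics; not part of this PR.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                    # 4 token positions, vocab size 10
labels = torch.tensor([3, -100, 7, -100])      # positions 1 and 3 are masked out

# F.cross_entropy skips targets equal to ignore_index entirely
summed = F.cross_entropy(logits, labels, ignore_index=-100, reduction="sum")
n_items = torch.count_nonzero(labels != -100)  # only 2 tokens contribute
mean_loss = summed / n_items                   # the normalization FastForCausalLMLoss applies
```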

def replace_custom_loss_when_triggered(
module_cls: Type,
custom_loss_type: str,
):

# this is a special trigger that will perform the replacement
def _trigger(mod):
if isinstance (mod, module_cls) and hasattr(mod, "loss_function"):
# guarded
from transformers.loss.loss_utils import LOSS_MAPPING
LOSS_MAPPING[custom_loss_type] = FastForCausalLMLoss
mod.loss_type = custom_loss_type
return True

return False

return _trigger

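Following the reviewer exchange above about a liger-kernel counterpart with the same API, below is a hypothetical sketch of what such a function might look like. It is an assumption, not part of this PR: it presumes liger_kernel.transformers exposes LigerCrossEntropyLoss with an nn.CrossEntropyLoss-style interface (ignore_index, reduction), and the name LigerForCausalLMLoss is made up here. The real plug-and-play function would be registered into LOSS_MAPPING the same way FastForCausalLMLoss is above.

```python
# Hypothetical sketch of a liger-kernel analogue of FastForCausalLMLoss; not in this PR.
# Assumes liger_kernel.transformers.LigerCrossEntropyLoss accepts ignore_index and
# reduction like torch.nn.CrossEntropyLoss -- verify against the installed liger version.
from liger_kernel.transformers import LigerCrossEntropyLoss


def LigerForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None,
    ignore_index: int = -100, **kwargs
):
    labels = labels.to(logits.device)
    # shift so that tokens < n predict n, mirroring FastForCausalLMLoss above
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)

    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss_fn = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    loss = loss_fn(shift_logits, shift_labels)
    if reduction == "sum":
        # normalize by the global token count supplied by the trainer
        loss = loss / num_items_in_batch
    return loss
```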

@@ -27,7 +27,10 @@

# Local
from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.cross_entropy_loss import (
FastCrossEntropyLoss,
replace_custom_loss_when_triggered,
)
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..utils import filter_mp_rules
@@ -37,6 +40,7 @@
KEY_QKV,
build_lora_fused_ops,
get_hidden_activation_fn_key,
get_transformers_version,
trigger_fused_ops,
)

@@ -122,16 +126,27 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
base_type=base_type,
),
),
# TODO: have a generic version of this rule
# - get the module_name and reload on that
ModelPatcherRule(
rule_id="granite-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.granite.modeling_granite",
),
),
*[
(
ModelPatcherRule(
rule_id="granite-custom-loss",
trigger=ModelPatcherTrigger(
check=replace_custom_loss_when_triggered(
GraniteForCausalLM, custom_loss_type="granite-custom-loss"
)
),
)
if get_transformers_version() >= "4.46"
else ModelPatcherRule(
rule_id="granite-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.granite.modeling_granite",
),
)
)
],
ModelPatcherRule(
rule_id="granite-fused-lce",
trigger=ModelPatcherTrigger(check=GraniteForCausalLM),
@@ -33,7 +33,10 @@

# Local
from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.cross_entropy_loss import (
FastCrossEntropyLoss,
replace_custom_loss_when_triggered,
)
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..utils import filter_mp_rules
@@ -43,6 +46,7 @@
KEY_QKV,
build_lora_fused_ops,
get_hidden_activation_fn_key,
get_transformers_version,
trigger_fused_ops,
)

@@ -122,14 +126,27 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
trigger=ModelPatcherTrigger(check=LlamaForCausalLM),
forward=lce_forward,
),
ModelPatcherRule(
rule_id="llama-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.llama.modeling_llama",
),
),
*[
(
ModelPatcherRule(
rule_id="llama-custom-loss",
trigger=ModelPatcherTrigger(
check=replace_custom_loss_when_triggered(
LlamaForCausalLM, custom_loss_type="llama-custom-loss"
)
),
)
if get_transformers_version() >= "4.46"
else ModelPatcherRule(
rule_id="llama-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.llama.modeling_llama",
),
)
)
],
# TODO: have a generic version of this rule
# - get the module name
# - check if "apply_rotary_pos_emb" exists
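A closing note on the get_transformers_version() >= "4.46" gates used in the granite and llama rules above: that helper comes from ..utils and is not shown in this diff, so this is only a caution. If it returns a raw string, the comparison is lexicographic and would misorder versions such as "4.100" vs "4.46". A minimal sketch of a robust check, making no assumption about the existing helper:

```python
# Sketch only: a version gate based on packaging, independent of how
# get_transformers_version() is actually implemented in ..utils (not shown here).
from packaging.version import parse as parse_version

import transformers


def transformers_at_least(minimum: str) -> bool:
    # packaging orders "4.100" after "4.46", and treats pre-releases such as
    # "4.46.0.dev0" as older than "4.46.0"
    return parse_version(transformers.__version__) >= parse_version(minimum)


# e.g. pick the trigger-based custom-loss rule only on transformers >= 4.46
use_custom_loss_rule = transformers_at_least("4.46")
```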