diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
index 07a7f772..a54073c3 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
@@ -33,7 +33,6 @@
     PreTrainedModel,
 )
 from transformers.modeling_utils import (
-    dtype_byte_size,
     is_local_dist_rank_0,
     no_init_weights,
 )
@@ -829,8 +828,9 @@ def shard_checkpoint(
             block_id = storage_id_to_block[storage_id]
             sharded_state_dicts[block_id][key] = weight
             continue
-
-        weight_size = weight.numel() * dtype_byte_size(weight.dtype)
+        # dtype_byte_size is no longer supported in transformers
+        # due to its inaccuracies - https://github.com/huggingface/transformers/pull/37144
+        weight_size = weight.numel() * weight.element_size()
         # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
         # weight in the current shard.
         if (
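
For reference, a minimal sketch (not part of the patch) of what the replacement computes: torch.Tensor.element_size() returns the number of bytes per element, so numel() * element_size() gives the same per-tensor byte count that dtype_byte_size(weight.dtype) used to supply for standard dtypes. The tensor shape and dtype below are arbitrary illustrations.

import torch

# Hypothetical example tensor; any shape/dtype works the same way.
weight = torch.zeros(4, 8, dtype=torch.float16)  # 32 elements, 2 bytes each

# Same computation as the patched line in shard_checkpoint().
weight_size = weight.numel() * weight.element_size()
assert weight_size == 64  # 32 elements * 2 bytes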