diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py
index fbe6d3297..ae37410ae 100644
--- a/fast_llm/engine/multi_stage/fsdp.py
+++ b/fast_llm/engine/multi_stage/fsdp.py
@@ -270,7 +270,9 @@ def reset_shard_pad(self, shard: torch.Tensor, shard_name: str) -> int:
         # Also ensures a correct parameter count in loading context.
         shard_meta = self._weight_shard_meta if shard_name == ShardName.weights else self._grad_shard_meta
         shard_meta.validate(shard)
-        if self._shard_pad > 0:
+        # Only count padding for non-empty shards. Frozen FSDPs have empty optimizer shards
+        # (numel()==0) but non-zero shard_pad, which would incorrectly inflate the count.
+        if self._shard_pad > 0 and shard.numel() > 0:
             shard[-self._shard_pad :].zero_()
             return self._shard_pad
         return 0
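
A minimal standalone sketch of the guarded behavior (the helper `reset_shard_pad_sketch` below is hypothetical, not part of the codebase) illustrates why the added `shard.numel() > 0` check matters: an empty optimizer shard from a frozen FSDP should contribute zero padding rather than its nominal `shard_pad`.

```python
import torch


def reset_shard_pad_sketch(shard: torch.Tensor, shard_pad: int) -> int:
    # Zero out the trailing padding and report its size, but only when the
    # shard actually holds data. A frozen FSDP's optimizer shard is empty
    # (numel() == 0) yet still carries a non-zero shard_pad, so counting it
    # would inflate the reported padding.
    if shard_pad > 0 and shard.numel() > 0:
        shard[-shard_pad:].zero_()
        return shard_pad
    return 0


# Trainable shard: padding is zeroed and counted.
assert reset_shard_pad_sketch(torch.ones(10), 3) == 3
# Frozen FSDP's empty optimizer shard: nothing to zero, padding not counted.
assert reset_shard_pad_sketch(torch.empty(0), 3) == 0
```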