
Commit 22c70ad

Distributed load padding correction (#439)
1 parent 057aff2 commit 22c70ad

File tree

1 file changed: +3 −1 lines changed


fast_llm/engine/multi_stage/fsdp.py

Lines changed: 3 additions & 1 deletion
@@ -270,7 +270,9 @@ def reset_shard_pad(self, shard: torch.Tensor, shard_name: str) -> int:
         # Also ensures a correct parameter count in loading context.
         shard_meta = self._weight_shard_meta if shard_name == ShardName.weights else self._grad_shard_meta
         shard_meta.validate(shard)
-        if self._shard_pad > 0:
+        # Only count padding for non-empty shards. Frozen FSDPs have empty optimizer shards
+        # (numel()==0) but non-zero shard_pad, which would incorrectly inflate the count.
+        if self._shard_pad > 0 and shard.numel() > 0:
             shard[-self._shard_pad :].zero_()
             return self._shard_pad
         return 0
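
The sketch below illustrates why the added shard.numel() > 0 guard matters. It is a minimal stand-in, not the repository's actual method: the function name, signature, and standalone form are assumptions, with only the guarded branch mirroring the diff above. A frozen FSDP can hand in an empty optimizer shard while still carrying a non-zero pad size, and without the guard the method would report padding that the shard does not actually contain, inflating the parameter count used during loading.

import torch

def reset_shard_pad_sketch(shard: torch.Tensor, shard_pad: int) -> int:
    # Hypothetical simplified stand-in for FSDP.reset_shard_pad.
    # Zero the padding tail of a non-empty shard and report how many padded
    # elements were counted; an empty shard contributes no padding at all.
    if shard_pad > 0 and shard.numel() > 0:
        shard[-shard_pad:].zero_()
        return shard_pad
    return 0

# Regular shard: the padding tail is zeroed and counted.
weights = torch.ones(10)
assert reset_shard_pad_sketch(weights, shard_pad=2) == 2
assert weights[-2:].sum() == 0

# Empty optimizer shard of a frozen FSDP: nothing to zero, nothing to count.
# Before the fix, the pad size would still have been returned here.
frozen = torch.empty(0)
assert reset_shard_pad_sketch(frozen, shard_pad=2) == 0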
