
Commit 1b6d7a7 (1 parent: feaeaa5)

handle requires_grad in shard_moe

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

File tree

1 file changed: +4 −1 lines

plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py

Lines changed: 4 additions & 1 deletion
@@ -215,6 +215,8 @@ def load_sharded_experts_onto_device(
         else:
             mod_dtype = getattr(mod, name).dtype

+        requires_grad = getattr(mod, name).requires_grad
+
         # the megablocks dmoe experts the expert features to be on DIM_EXPERT.
         # - concat on dim 0 and distribute
         # - cast to the correct dtype for the module
@@ -227,7 +229,8 @@ def load_sharded_experts_onto_device(
             _placements = [Replicate() for _ in range(len(placements))]

         param = torch.nn.Parameter(
-            distribute_tensor(param, device_mesh, _placements)
+            distribute_tensor(param, device_mesh, _placements),
+            requires_grad=requires_grad
         )

         # register the sharded parameter onto the megablocks.dmoe
