Skip to content

Commit b3cd720

Browse files
committed
skip set_reduce_scatter_divide_factor on NPU
1 parent 9370318 commit b3cd720

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

veomni/models/auto.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,21 @@ def wrapped_forward(*args, **kwargs):
144144

145145
model.forward = wrapped_forward
146146

147+
# override fsdp2 set_reduce_scatter_divide_factor API since HCCL does not support PreSumMul yet.
148+
logger.info("Overriding fsdp2 set_reduce_scatter_divide_factor API since HCCL does not support PreSumMul yet.")
149+
logger.info("Due to this issue, NPU now does not support HSDP.")
150+
from torch.distributed.fsdp._fully_shard import _fully_shard as fsdp_fully_shard
151+
152+
def _skip_set_reduce_scatter_divide_factor(*args, **kwargs):
153+
# no-op: keep signature flexible so future changes don't crash
154+
return
155+
156+
# monkey-patch the API used by FSDP2
157+
if hasattr(fsdp_fully_shard, "set_reduce_scatter_divide_factor"):
158+
fsdp_fully_shard.set_reduce_scatter_divide_factor = _skip_set_reduce_scatter_divide_factor
159+
else:
160+
logger.warning(
161+
"fsdp_fully_shard has no attribute set_reduce_scatter_divide_factor; PyTorch version may have changed."
162+
)
163+
147164
return model

0 commit comments

Comments
 (0)