File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -144,4 +144,21 @@ def wrapped_forward(*args, **kwargs):
144144
145145 model .forward = wrapped_forward
146146
147+ # override fsdp2 set_reduce_scatter_divide_factor API since HCCL does not support PreSumMul yet.
148+ logger .info ("Overriding fsdp2 set_reduce_scatter_divide_factor API since HCCL does not support PreSumMul yet." )
149+ logger .info ("Due to this issue, NPU now does not support HSDP." )
150+ from torch .distributed .fsdp ._fully_shard import _fully_shard as fsdp_fully_shard
151+
152+ def _skip_set_reduce_scatter_divide_factor (* args , ** kwargs ):
153+ # no-op: keep signature flexible so future changes don't crash
154+ return
155+
156+ # monkey-patch the API used by FSDP2
157+ if hasattr (fsdp_fully_shard , "set_reduce_scatter_divide_factor" ):
158+ fsdp_fully_shard .set_reduce_scatter_divide_factor = _skip_set_reduce_scatter_divide_factor
159+ else :
160+ logger .warning (
161+ "fsdp_fully_shard has no attribute set_reduce_scatter_divide_factor; PyTorch version may have changed."
162+ )
163+
147164 return model
You can’t perform that action at this time.
0 commit comments