Skip to content

Commit c6fa522

Browse files
committed
related-change with deepspeed#5445
1 parent 89bb319 commit c6fa522

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ def IPEX_WEIGHT_PREPACK_MODULE_CPU():
4646
deepspeed_modules_mapping.update(
4747
{LmHeadLinearAllreduce: _IPEXLmHeadLinearAllreduce}
4848
)
49+
if len(deepspeed_modules) > 3:
50+
for module in deepspeed_modules[3:]:
51+
if module not in deepspeed_modules_mapping:
52+
if issubclass(module, LinearAllreduce):
53+
deepspeed_modules_mapping[module] = _IPEXLinearAllreduce
54+
elif issubclass(module, LinearLayer):
55+
deepspeed_modules_mapping[module] = _IPEXLinear
4956
torch_modules.update(deepspeed_modules_mapping)
5057

5158
return torch_modules

intel_extension_for_pytorch/nn/utils/_weight_prepack.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def may_import_deepspeed_modules():
101101
try:
102102
# import deepspeed in a global space will raise circular import error
103103
# intel-extension-for-deepspeed imports both IPEX and deepspeed
104+
import deepspeed.module_inject.layers as dslayers
104105
from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer
105106

106107
ds_layers = [LinearAllreduce, LinearLayer]
@@ -110,6 +111,8 @@ def may_import_deepspeed_modules():
110111
from deepspeed.module_inject.layers import LmHeadLinearAllreduce
111112

112113
ds_layers.append(LmHeadLinearAllreduce)
114+
ds_layers += [cls for cls in dslayers.LinearAllreduce.__subclasses__()]
115+
ds_layers += [cls for cls in dslayers.LinearLayer.__subclasses__()]
113116
return ds_layers
114117
except ImportError:
115118
return ds_layers

0 commit comments

Comments
 (0)