File tree Expand file tree Collapse file tree 2 files changed +10
-0
lines changed
intel_extension_for_pytorch/nn/utils Expand file tree Collapse file tree 2 files changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -46,6 +46,13 @@ def IPEX_WEIGHT_PREPACK_MODULE_CPU():
4646 deepspeed_modules_mapping .update (
4747 {LmHeadLinearAllreduce : _IPEXLmHeadLinearAllreduce }
4848 )
49+ if len (deepspeed_modules ) > 3 :
50+ for module in deepspeed_modules [3 :]:
51+ if module not in deepspeed_modules_mapping :
52+ if issubclass (module , LinearAllreduce ):
53+ deepspeed_modules_mapping [module ] = _IPEXLinearAllreduce
54+ elif issubclass (module , LinearLayer ):
55+ deepspeed_modules_mapping [module ] = _IPEXLinear
4956 torch_modules .update (deepspeed_modules_mapping )
5057
5158 return torch_modules
Original file line number Diff line number Diff line change @@ -101,6 +101,7 @@ def may_import_deepspeed_modules():
101101 try :
102102 # import deepspeed in a global space will raise circular import error
103103 # intel-extension-for-deepspeed imports both IPEX and deepspeed
104+ import deepspeed .module_inject .layers as dslayers
104105 from deepspeed .module_inject .layers import LinearAllreduce , LinearLayer
105106
106107 ds_layers = [LinearAllreduce , LinearLayer ]
@@ -110,6 +111,8 @@ def may_import_deepspeed_modules():
110111 from deepspeed .module_inject .layers import LmHeadLinearAllreduce
111112
112113 ds_layers .append (LmHeadLinearAllreduce )
114+ ds_layers += [cls for cls in dslayers .LinearAllreduce .__subclasses__ ()]
115+ ds_layers += [cls for cls in dslayers .LinearLayer .__subclasses__ ()]
113116 return ds_layers
114117 except ImportError :
115118 return ds_layers
You can’t perform that action at this time.
0 commit comments