@@ -744,13 +744,13 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device.type
-        if self.module_device == "cpu":
+        self.module_device = next(module.parameters()).device
+        if self.module_device.type == "cpu":
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
                 self.mlp_linear_add = LinearAdd(module.down_proj)
             self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
-        elif self.module_device == "xpu":
+        elif self.module_device.type == "xpu":
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
                 self.mlp_linear_add = XPULinearAdd(module.down_proj)
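This hunk keeps the full `torch.device` instead of only its `.type` string: the `.type` comparisons still work as before, and the full device object also preserves the device index (e.g. `xpu:1`), which the bare type string would drop. A minimal sketch of the distinction (the `nn.Linear` module and `mask` tensor below are illustrative, not from the patch):

    import torch
    import torch.nn as nn

    module = nn.Linear(4, 4)                   # any module with parameters
    device = next(module.parameters()).device  # torch.device, e.g. device(type='cpu')

    # Branching still uses the string component, as in the patch:
    if device.type == "cpu":
        pass  # CPU-only fused ops
    elif device.type == "xpu":
        pass  # XPU-only fused ops

    # The full device also carries the index (e.g. xpu:1), which the bare
    # type string drops, and it can be passed straight to tensor constructors:
    mask = torch.zeros(4, device=device)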
@@ -777,15 +777,15 @@ def __init__(self, module, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-        self.module_device = next(module.parameters()).device.type
-        if self.module_device == "cpu":
+        self.module_device = next(module.parameters()).device
+        if self.module_device.type == "cpu":
             self.linear_gelu = LinearGelu(module.dense_h_to_4h)
-        elif self.module_device == "xpu":
+        elif self.module_device.type == "xpu":
             self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
         if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
-            if self.module_device == "cpu":
+            if self.module_device.type == "cpu":
                 self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
-            elif self.module_device == "xpu":
+            elif self.module_device.type == "xpu":
                 self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)
 
     def forward(
@@ -870,7 +870,11 @@ class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):
         super().__init__()
         _setattr_from_module(self, module)
-        self.linear_gelu = LinearGelu(module.dense)
+        self.module_device = next(module.parameters()).device
+        if self.module_device.type == "cpu":
+            self.linear_gelu = LinearGelu(module.dense)
+        elif self.module_device.type == "xpu":
+            self.linear_gelu = XPULinearGelu(module.dense)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.linear_gelu(hidden_states)
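With the new `elif`, `_IPEXIntermediate` now selects the fused op for both supported devices; since `forward` calls `self.linear_gelu` unconditionally, any other device type would leave the attribute unset and make `forward` raise `AttributeError`. A toy harness of the dispatch pattern, using an unfused stand-in for the fused ops (class names below are illustrative, not from the patch):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class StandInLinearGelu(nn.Module):
        """Unfused stand-in for the fused LinearGelu / XPULinearGelu ops."""
        def __init__(self, dense: nn.Linear):
            super().__init__()
            self.dense = dense

        def forward(self, x):
            return F.gelu(self.dense(x))

    class ToyIntermediate(nn.Module):
        def __init__(self, dense: nn.Linear):
            super().__init__()
            self.module_device = next(dense.parameters()).device
            if self.module_device.type == "cpu":
                self.linear_gelu = StandInLinearGelu(dense)  # LinearGelu(module.dense) in the patch
            elif self.module_device.type == "xpu":
                self.linear_gelu = StandInLinearGelu(dense)  # XPULinearGelu(module.dense) in the patch

        def forward(self, hidden_states):
            return self.linear_gelu(hidden_states)

    toy = ToyIntermediate(nn.Linear(8, 16))
    print(toy(torch.randn(2, 8)).shape)  # torch.Size([2, 16])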