Skip to content

Commit 6ea6b5d

Browse files
ipex Page attn xpu support bug fix (#1053)
* fix ipex xpu support issues Signed-off-by: Liu, Kaixuan <[email protected]> * use `device_map` Signed-off-by: Liu, Kaixuan <[email protected]> * small adjust Signed-off-by: Liu, Kaixuan <[email protected]> * to be compatible with openvino Signed-off-by: Liu, Kaixuan <[email protected]> * fix format Signed-off-by: Liu, Kaixuan <[email protected]> * refine code Signed-off-by: Liu, Kaixuan <[email protected]> * Update tests/ipex/test_modeling.py * update code Signed-off-by: Liu, Kaixuan <[email protected]> --------- Signed-off-by: Liu, Kaixuan <[email protected]> Co-authored-by: Ilyas Moutawwakil <[email protected]>
1 parent a6cb0c0 commit 6ea6b5d

File tree

4 files changed

+104
-79
lines changed

4 files changed

+104
-79
lines changed

optimum/exporters/ipex/modeling_utils.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -744,13 +744,13 @@ def __init__(self, module, config) -> None:
744744
super().__init__()
745745
_setattr_from_module(self, module)
746746
self.config = config
747-
self.module_device = next(module.parameters()).device.type
748-
if self.module_device == "cpu":
747+
self.module_device = next(module.parameters()).device
748+
if self.module_device.type == "cpu":
749749
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
750750
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
751751
self.mlp_linear_add = LinearAdd(module.down_proj)
752752
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
753-
elif self.module_device == "xpu":
753+
elif self.module_device.type == "xpu":
754754
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
755755
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
756756
self.mlp_linear_add = XPULinearAdd(module.down_proj)
@@ -777,15 +777,15 @@ def __init__(self, module, config) -> None:
777777
_setattr_from_module(self, module)
778778
self.config = config
779779
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
780-
self.module_device = next(module.parameters()).device.type
781-
if self.module_device == "cpu":
780+
self.module_device = next(module.parameters()).device
781+
if self.module_device.type == "cpu":
782782
self.linear_gelu = LinearGelu(module.dense_h_to_4h)
783-
elif self.module_device == "xpu":
783+
elif self.module_device.type == "xpu":
784784
self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
785785
if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
786-
if self.module_device == "cpu":
786+
if self.module_device.type == "cpu":
787787
self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
788-
elif self.module_device == "xpu":
788+
elif self.module_device.type == "xpu":
789789
self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)
790790

791791
def forward(
@@ -870,7 +870,11 @@ class _IPEXIntermediate(nn.Module):
870870
def __init__(self, module, config):
871871
super().__init__()
872872
_setattr_from_module(self, module)
873-
self.linear_gelu = LinearGelu(module.dense)
873+
self.module_device = next(module.parameters()).device
874+
if self.module_device.type == "cpu":
875+
self.linear_gelu = LinearGelu(module.dense)
876+
elif self.module_device.type == "xpu":
877+
self.linear_gelu = XPULinearGelu(module.dense)
874878

875879
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
876880
hidden_states = self.linear_gelu(hidden_states)

optimum/intel/pipelines/pipeline_base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,14 +246,17 @@ def load_ipex_model(
246246
SUPPORTED_TASKS,
247247
hub_kwargs: Optional[Dict[str, Any]] = None,
248248
model_kwargs: Optional[Dict[str, Any]] = None,
249+
device_map: Optional[torch.device] = None,
249250
):
250251
hub_kwargs = hub_kwargs or {}
251252
model_kwargs = model_kwargs or {}
252253
ipex_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]
253254

254255
if model is None:
255256
model_id = SUPPORTED_TASKS[targeted_task]["default"]
256-
model = ipex_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs)
257+
model = ipex_model_class.from_pretrained(
258+
model_id, export=True, **hub_kwargs, **model_kwargs, device_map=device_map
259+
)
257260
elif isinstance(model, str):
258261
model_id = model
259262
try:
@@ -262,7 +265,9 @@ def load_ipex_model(
262265
except RuntimeError:
263266
logger.warning("We will use IPEXModel with export=True to export the model")
264267
export = True
265-
model = ipex_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs)
268+
model = ipex_model_class.from_pretrained(
269+
model, export=export, **hub_kwargs, **model_kwargs, device_map=device_map
270+
)
266271
elif isinstance(model, IPEXModel):
267272
model_id = getattr(model.config, "name_or_path", None)
268273
else:

0 commit comments

Comments (0)