```diff
@@ -39,6 +39,7 @@
     GenerationConfig,
     GenerationMixin,
     PretrainedConfig,
+    is_torch_xpu_available,
 )
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
@@ -52,7 +53,7 @@
 from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device


 logger = logging.getLogger(__name__)
```
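The new `is_torch_xpu_available` import comes from `transformers`. With `check_device=True` it goes beyond checking that the XPU backend is built into torch and also probes the runtime for an actual device; a minimal sketch of the difference, assuming a recent `transformers` (the exact probing behaviour is an assumption, not shown in this diff):

```python
from transformers import is_torch_xpu_available

# Default: only reports whether the XPU backend is available to torch at all.
print(is_torch_xpu_available())

# check_device=True additionally queries the runtime for a usable device, so it
# can return False on machines where the backend exists but no XPU is present.
print(is_torch_xpu_available(check_device=True))
```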
```diff
@@ -128,10 +129,14 @@ def __init__(
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
-        # To do: add XPU support
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
+        if is_torch_xpu_available(check_device=True):
+            self._device = torch.device("xpu:0")
+        elif torch.cuda.is_available():
+            self._device = torch.device("cuda:0")
+        else:
+            self._device = torch.device("cpu")
         self.model.to(self._device)
+        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir
         self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)

```
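The fallback order introduced here (XPU first, then CUDA, then CPU) can be exercised in isolation; a minimal sketch, where `pick_device` is a hypothetical helper rather than part of the patch:

```python
import torch
from transformers import is_torch_xpu_available


def pick_device() -> torch.device:
    # Hypothetical helper mirroring the fallback order added to __init__ above:
    # prefer an Intel XPU when one is actually present, then CUDA, then CPU.
    if is_torch_xpu_available(check_device=True):
        return torch.device("xpu:0")
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


print(pick_device())  # e.g. device(type='cpu') on a machine without accelerators
```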
```diff
@@ -321,6 +326,8 @@ def _init_warmup(self):
         if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+            if self._device.type != "cpu":
+                dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
             for _ in range(2):
                 self(**dummy_inputs)

```
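`recursive_to_device` is imported from `..utils.modeling_utils`, but its body is not part of this diff. The warmup path only needs it to move every tensor inside the (possibly nested) dummy inputs onto the target device; a minimal sketch of that behaviour, written as an assumption about the helper's semantics:

```python
import torch


def recursive_to_device(value, device):
    # Sketch (the real helper lives in ..utils.modeling_utils): walk nested
    # lists/tuples/dicts and move tensor leaves to `device`, preserving the
    # container types; non-tensor leaves are returned unchanged.
    if isinstance(value, (list, tuple)):
        return type(value)(recursive_to_device(v, device) for v in value)
    if isinstance(value, dict):
        return {k: recursive_to_device(v, device) for k, v in value.items()}
    if isinstance(value, torch.Tensor):
        return value.to(device)
    return value
```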