Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions library/ipex/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import sys
import torch
from packaging import version
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True
except Exception:
has_ipex = False
from .hijacks import ipex_hijacks

torch_version = float(torch.__version__[:3])
torch_version = version.parse(torch.__version__)

# pylint: disable=protected-access, missing-function-docstring, line-too-long

Expand Down Expand Up @@ -56,22 +57,19 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing

if torch_version < 2.3:
if torch_version < version.parse("2.3"):
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
Expand Down Expand Up @@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback

if torch_version < 2.5:
if torch_version < version.parse("2.5"):
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu

if torch_version < 2.7:
if torch_version < version.parse("2.7"):
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List

if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union


# Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
Expand Down Expand Up @@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed

# C
if torch_version < 2.3:
if torch_version < version.parse("2.3"):
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12
Expand Down
Loading