Skip to content

Commit e936e24

Browse files
committed
optimize(infer): move ipex into rvc
1 parent cc11ad4 commit e936e24

File tree

8 files changed

+18
-12
lines changed

8 files changed

+18
-12
lines changed

configs/config.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,7 @@
66
from multiprocessing import cpu_count
77

88
import torch
9-
10-
try:
11-
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
12-
13-
if torch.xpu.is_available():
14-
from infer.modules.ipex import ipex_init
15-
16-
ipex_init()
17-
except Exception: # pylint: disable=broad-exception-caught
18-
pass
9+
# TODO: move device selection into rvc
1910
import logging
2011

2112
logger = logging.getLogger(__name__)

infer/modules/train/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
2525

2626
if torch.xpu.is_available():
27-
from infer.modules.ipex import ipex_init
28-
from infer.modules.ipex.gradscaler import gradscaler_init
27+
from rvc.ipex import ipex_init, gradscaler_init
2928
from torch.xpu.amp import autocast
3029

3130
GradScaler = gradscaler_init()

rvc/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import ipex
2+
import sys
3+
del sys.modules["rvc.ipex"]
4+

rvc/ipex/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
try:
2+
import torch
3+
if torch.xpu.is_available():
4+
from .init import ipex_init
5+
ipex_init()
6+
from .gradscaler import gradscaler_init
7+
except Exception: # pylint: disable=broad-exception-caught
8+
pass
File renamed without changes.

infer/modules/ipex/gradscaler.py renamed to rvc/ipex/gradscaler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import defaultdict
2+
23
import torch
34
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
45
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import

infer/modules/ipex/hijacks.py renamed to rvc/ipex/hijacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import importlib
3+
34
import torch
45
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
56

infer/modules/ipex/__init__.py renamed to rvc/ipex/init.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
import sys
33
import contextlib
4+
45
import torch
56
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
7+
68
from .hijacks import ipex_hijacks
79
from .attention import attention_init
810

0 commit comments

Comments
 (0)