
Commit 76b6784

Format code (#989)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 7293002 commit 76b6784

File tree

10 files changed: 218 additions, 117 deletions
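
Every hunk in the five files shown below (only five of the ten changed files survive in this capture) is consistent with a run of the black formatter under its defaults: spaces added around `=`, a magic trailing comma after the last element of a multi-line call, calls past the 88-character line limit exploded to one argument per line, and blank lines normalized around imports and definitions. No hunk appears to change runtime behavior.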

MDXNet.py

Lines changed: 8 additions & 2 deletions
@@ -83,12 +83,13 @@ def get_models(device, dim_f, dim_t, n_fft):
 
 warnings.filterwarnings("ignore")
 import sys
+
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 from config import Config
 
 cpu = torch.device("cpu")
-device=Config().device
+device = Config().device
 # if torch.cuda.is_available():
 #     device = torch.device("cuda:0")
 # elif torch.backends.mps.is_available():
@@ -104,10 +105,15 @@ def __init__(self, args):
             device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
         )
         import onnxruntime as ort
+
         print(ort.get_available_providers())
         self.model = ort.InferenceSession(
             os.path.join(args.onnx, self.model_.target_name + ".onnx"),
-            providers=["CUDAExecutionProvider", "DmlExecutionProvider","CPUExecutionProvider"],
+            providers=[
+                "CUDAExecutionProvider",
+                "DmlExecutionProvider",
+                "CPUExecutionProvider",
+            ],
         )
         print("onnx load done")
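
The exploded `providers` list is a priority order: ONNX Runtime tries `CUDAExecutionProvider` first, then `DmlExecutionProvider` (DirectML), and `CPUExecutionProvider` is the always-available fallback; providers absent from the installed build are skipped with a warning. A minimal sketch of the same pattern, assuming a local `model.onnx` in place of the path assembled from `args.onnx` and the model's target name:

import onnxruntime as ort

# Providers compiled into this onnxruntime build, e.g.
# ["CUDAExecutionProvider", "CPUExecutionProvider"].
print(ort.get_available_providers())

# Request providers in priority order; unavailable ones are skipped,
# so CPUExecutionProvider acts as the universal fallback.
session = ort.InferenceSession(
    "model.onnx",  # hypothetical path, for illustration only
    providers=[
        "CUDAExecutionProvider",
        "DmlExecutionProvider",
        "CPUExecutionProvider",
    ],
)
print(session.get_providers())  # the providers the session actually bound
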
config.py

Lines changed: 19 additions & 7 deletions
@@ -36,7 +36,7 @@ def __init__(self):
             self.iscolab,
             self.noparallel,
             self.noautoopen,
-            self.dml
+            self.dml,
         ) = self.arg_parse()
         self.instead = ""
         self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
@@ -71,7 +71,7 @@ def arg_parse() -> tuple:
             cmd_opts.colab,
             cmd_opts.noparallel,
             cmd_opts.noautoopen,
-            cmd_opts.dml
+            cmd_opts.dml,
         )
 
     # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
@@ -149,26 +149,38 @@ def device_config(self) -> tuple:
         if self.dml:
             print("use DirectML instead")
             try:
-                os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
+                os.rename(
+                    "runtime\Lib\site-packages\onnxruntime",
+                    "runtime\Lib\site-packages\onnxruntime-cuda",
+                )
             except:
                 pass
             try:
-                os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
+                os.rename(
+                    "runtime\Lib\site-packages\onnxruntime-dml",
+                    "runtime\Lib\site-packages\onnxruntime",
+                )
             except:
-
                 pass
             import torch_directml
+
             self.device = torch_directml.device(torch_directml.default_device())
             self.is_half = False
         else:
             if self.instead:
                 print(f"use {self.instead} instead")
             try:
-                os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
+                os.rename(
+                    "runtime\Lib\site-packages\onnxruntime",
+                    "runtime\Lib\site-packages\onnxruntime-cuda",
+                )
             except:
                 pass
             try:
-                os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
+                os.rename(
+                    "runtime\Lib\site-packages\onnxruntime-dml",
+                    "runtime\Lib\site-packages\onnxruntime",
+                )
             except:
                 pass
         return x_pad, x_query, x_center, x_max
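
The paired `os.rename` calls are a crude build switch for the bundled portable runtime: whichever directory sits at `runtime\Lib\site-packages\onnxruntime` is what `import onnxruntime` resolves to, so the CUDA build is parked as `onnxruntime-cuda` and the DirectML build (`onnxruntime-dml`) is moved into its place. The path literals only work because `\L`, `\s`, and `\o` happen not to be recognized escape sequences; raw strings avoid that trap, and newer Python versions warn about such literals. A minimal sketch of the same swap with raw strings and explicit checks instead of bare `except` (the helper name is invented for illustration):

import os

SITE = r"runtime\Lib\site-packages"  # portable-runtime layout assumed by config.py

def rename_if_present(src_name: str, dst_name: str) -> None:
    # Swap directory names only when the source exists and the target
    # slot is free, rather than swallowing every error with a bare except.
    src = os.path.join(SITE, src_name)
    dst = os.path.join(SITE, dst_name)
    if os.path.isdir(src) and not os.path.exists(dst):
        os.rename(src, dst)

# Park the CUDA build, then promote the DirectML build.
rename_if_present("onnxruntime", "onnxruntime-cuda")
rename_if_present("onnxruntime-dml", "onnxruntime")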

extract_f0_rmvpe_dml.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 
 exp_dir = sys.argv[1]
 import torch_directml
+
 device = torch_directml.device(torch_directml.default_device())
 f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
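
For context, `torch_directml.default_device()` returns the index of the default DirectML adapter, and `torch_directml.device(...)` wraps it as a `torch.device` usable like `"cuda"` or `"mps"`. A quick sketch, assuming the `torch-directml` package is installed:

import torch
import torch_directml

dml = torch_directml.device(torch_directml.default_device())
x = torch.ones(3, device=dml)
print(x.device)  # DirectML registers as the "privateuseone" backend, e.g. privateuseone:0

That backend name is why extract_feature_print.py below tests for `"privateuseone"` in its device string.
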
extract_feature_print.py

Lines changed: 6 additions & 3 deletions
@@ -3,7 +3,7 @@
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
 
-device=sys.argv[1]
+device = sys.argv[1]
 n_part = int(sys.argv[2])
 i_part = int(sys.argv[3])
 if len(sys.argv) == 6:
@@ -20,20 +20,23 @@
 import numpy as np
 import fairseq
 
-if("privateuseone"not in device):
+if "privateuseone" not in device:
     device = "cpu"
     if torch.cuda.is_available():
         device = "cuda"
     elif torch.backends.mps.is_available():
         device = "mps"
 else:
     import torch_directml
+
     device = torch_directml.device(torch_directml.default_device())
+
     def forward_dml(ctx, x, scale):
         ctx.scale = scale
         res = x.clone().detach()
         return res
-    fairseq.modules.grad_multiply.GradMultiply.forward=forward_dml
+
+    fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
 
 f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
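
The substantive code here is the monkey-patch the formatter is tidying: fairseq's `GradMultiply` is a `torch.autograd.Function` that returns a copy of its input on the forward pass and multiplies gradients by a scale on the backward pass. Its stock `forward` builds the copy with `x.new(x)`, which the DirectML backend apparently cannot execute, so the script swaps in `x.clone().detach()` with the same semantics and leaves `backward` untouched. A minimal check of the patched behavior, assuming fairseq is installed:

import torch
from fairseq.modules.grad_multiply import GradMultiply

def forward_dml(ctx, x, scale):
    # Same contract as the original: stash the scale for backward and
    # return a copy of x, built from ops DirectML supports.
    ctx.scale = scale
    return x.clone().detach()

GradMultiply.forward = forward_dml

x = torch.randn(4, requires_grad=True)
y = GradMultiply.apply(x, 0.5)  # forward value equals x
y.sum().backward()
print(x.grad)  # all entries 0.5: backward still multiplies by ctx.scale
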
gui_v1.py

Lines changed: 5 additions & 4 deletions
@@ -1,5 +1,6 @@
-import os, sys,pdb
-os.environ["OMP_NUM_THREADS"]="2"
+import os, sys, pdb
+
+os.environ["OMP_NUM_THREADS"] = "2"
 if sys.platform == "darwin":
     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
@@ -47,8 +48,9 @@ def run(self):
     import torchaudio.transforms as tat
     from i18n import I18nAuto
     import rvc_for_realtime
+
     i18n = I18nAuto()
-    device=rvc_for_realtime.config.device
+    device = rvc_for_realtime.config.device
     # device = torch.device(
     #     "cuda"
     #     if torch.cuda.is_available()
@@ -61,7 +63,6 @@ def run(self):
     for _ in range(n_cpu):
         Harvest(inp_q, opt_q).start()
 
-
     class GUIConfig:
         def __init__(self) -> None:
             self.pth_path: str = ""
