Skip to content

Commit 79a79c3

Browse files
authored
Update config.py
1 parent 28948f8 commit 79a79c3

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

config.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
import torch
33
from multiprocessing import cpu_count
44

5+
def config_file_change_fp32():
6+
for config_file in ["32k.json", "40k.json", "48k.json"]:
7+
with open(f"configs/{config_file}", "r") as f:
8+
strr = f.read().replace("true", "false")
9+
with open(f"configs/{config_file}", "w") as f:
10+
f.write(strr)
11+
with open("trainset_preprocess_pipeline_print.py", "r") as f:
12+
strr = f.read().replace("3.7", "3.0")
13+
with open("trainset_preprocess_pipeline_print.py", "w") as f:
14+
f.write(strr)
515

616
class Config:
717
def __init__(self):
@@ -60,15 +70,7 @@ def device_config(self) -> tuple:
6070
):
6171
print("16系/10系显卡和P40强制单精度")
6272
self.is_half = False
63-
for config_file in ["32k.json", "40k.json", "48k.json"]:
64-
with open(f"configs/{config_file}", "r") as f:
65-
strr = f.read().replace("true", "false")
66-
with open(f"configs/{config_file}", "w") as f:
67-
f.write(strr)
68-
with open("trainset_preprocess_pipeline_print.py", "r") as f:
69-
strr = f.read().replace("3.7", "3.0")
70-
with open("trainset_preprocess_pipeline_print.py", "w") as f:
71-
f.write(strr)
73+
config_file_change_fp32()
7274
else:
7375
self.gpu_name = None
7476
self.gpu_mem = int(
@@ -87,10 +89,12 @@ def device_config(self) -> tuple:
8789
print("没有发现支持的N卡, 使用MPS进行推理")
8890
self.device = "mps"
8991
self.is_half = False
92+
config_file_change_fp32()
9093
else:
9194
print("没有发现支持的N卡, 使用CPU进行推理")
9295
self.device = "cpu"
9396
self.is_half = False
97+
config_file_change_fp32()
9498

9599
if self.n_cpu == 0:
96100
self.n_cpu = cpu_count()

0 commit comments

Comments
 (0)