Skip to content

Commit d3add81

Browse files
chore(format): run black on dev (#94)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent a8783c6 commit d3add81

File tree

10 files changed

+126
-47
lines changed

10 files changed

+126
-47
lines changed

infer/lib/audio.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
2828
am = 32767 * 32768 // am
2929
return np.multiply(audio, am).astype(np.int16)
3030

31+
3132
def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32=False) -> BytesIO:
3233
buf = BytesIO()
3334
if f32:
@@ -41,10 +42,12 @@ def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32=False) -> BytesIO:
4142
buf.seek(0, 0)
4243
return buf
4344

45+
4446
def save_audio(path: str, audio: np.ndarray, sr: int, f32=False):
4547
with open(path, "wb") as f:
4648
f.write(float_np_array_to_wav_buf(audio, sr, f32).getbuffer())
4749

50+
4851
def wav2(i: BytesIO, o: BufferedWriter, format: str):
4952
inp = av.open(i, "r")
5053
format = video_format_dict.get(format, format)
@@ -65,25 +68,40 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
6568

6669

6770
def load_audio(
68-
file: Union[str, BytesIO, Path],
69-
sr: Optional[int]=None,
70-
format: Optional[str]=None,
71-
mono=True
72-
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
73-
if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()):
71+
file: Union[str, BytesIO, Path],
72+
sr: Optional[int] = None,
73+
format: Optional[str] = None,
74+
mono=True,
75+
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
76+
if (isinstance(file, str) and not Path(file).exists()) or (
77+
isinstance(file, Path) and not file.exists()
78+
):
7479
raise FileNotFoundError(f"File not found: {file}")
7580
rate = 0
7681

7782
container = av.open(file, format=format)
7883
audio_stream = next(s for s in container.streams if s.type == "audio")
7984
channels = 1 if audio_stream.layout == "mono" else 2
8085
container.seek(0)
81-
resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr) if sr is not None else None
86+
resampler = (
87+
AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr)
88+
if sr is not None
89+
else None
90+
)
8291

8392
# Estimated maximum total number of samples to pre-allocate the array
8493
# AV stores length in microseconds by default
85-
estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000
86-
decoded_audio = np.zeros(estimated_total_samples + 1 if channels == 1 else (channels, estimated_total_samples + 1), dtype=np.float32)
94+
estimated_total_samples = (
95+
int(container.duration * sr // 1_000_000) if sr is not None else 48000
96+
)
97+
decoded_audio = np.zeros(
98+
(
99+
estimated_total_samples + 1
100+
if channels == 1
101+
else (channels, estimated_total_samples + 1)
102+
),
103+
dtype=np.float32,
104+
)
87105

88106
offset = 0
89107

@@ -92,7 +110,9 @@ def process_packet(packet: List[AudioFrame]):
92110
rate = 0
93111
for frame in packet:
94112
frame.pts = None # 清除时间戳,避免重新采样问题
95-
resampled_frames = resampler.resample(frame) if resampler is not None else [frame]
113+
resampled_frames = (
114+
resampler.resample(frame) if resampler is not None else [frame]
115+
)
96116
for resampled_frame in resampled_frames:
97117
frame_data = resampled_frame.to_ndarray()
98118
rate = resampled_frame.rate
@@ -104,13 +124,16 @@ def frame_iter(container):
104124
yield p.decode()
105125

106126
for r, frames_data in map(process_packet, frame_iter(container)):
107-
if not rate: rate = r
127+
if not rate:
128+
rate = r
108129
for frame_data in frames_data:
109130
end_index = offset + len(frame_data[0])
110131

111132
# 检查 decoded_audio 是否有足够的空间,并在必要时调整大小
112133
if end_index > decoded_audio.shape[1]:
113-
decoded_audio = np.resize(decoded_audio, (decoded_audio.shape[0], end_index*4))
134+
decoded_audio = np.resize(
135+
decoded_audio, (decoded_audio.shape[0], end_index * 4)
136+
)
114137

115138
np.copyto(decoded_audio[..., offset:end_index], frame_data)
116139
offset += len(frame_data[0])
@@ -126,7 +149,9 @@ def frame_iter(container):
126149
return decoded_audio, rate
127150

128151

129-
def downsample_audio(input_path: str, output_path: str, format: str, br=128_000) -> None:
152+
def downsample_audio(
153+
input_path: str, output_path: str, format: str, br=128_000
154+
) -> None:
130155
"""
131156
default to 128kb/s (equivalent to -q:a 2)
132157
"""

infer/lib/slicer2.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,15 @@ def main():
244244
for i, chunk in enumerate(chunks):
245245
if len(chunk.shape) > 1:
246246
chunk = chunk.T
247-
save_audio(os.path.join(
248-
out,
249-
f"%s_%d.wav"
250-
% (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
251-
), chunk, sr)
247+
save_audio(
248+
os.path.join(
249+
out,
250+
f"%s_%d.wav"
251+
% (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
252+
),
253+
chunk,
254+
sr,
255+
)
252256

253257

254258
if __name__ == "__main__":

infer/modules/train/preprocess.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,24 @@ def norm_write(self, tmp_audio, idx0, idx1):
6262
tmp_audio = (tmp_audio / tmp_max * (self.max * self.alpha)) + (
6363
1 - self.alpha
6464
) * tmp_audio
65-
save_audio("%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1), tmp_audio, self.sr, f32=True)
65+
save_audio(
66+
"%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
67+
tmp_audio,
68+
self.sr,
69+
f32=True,
70+
)
6671
with open("%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1), "wb") as f:
67-
f.write(float_np_array_to_wav_buf(
68-
load_audio(
69-
float_np_array_to_wav_buf(tmp_audio, self.sr, f32=True),
70-
sr=16000,
71-
format="wav",
72-
)
73-
, 16000, True).getbuffer())
72+
f.write(
73+
float_np_array_to_wav_buf(
74+
load_audio(
75+
float_np_array_to_wav_buf(tmp_audio, self.sr, f32=True),
76+
sr=16000,
77+
format="wav",
78+
),
79+
16000,
80+
True,
81+
).getbuffer()
82+
)
7483

7584
def pipeline(self, path, idx0):
7685
try:

infer/modules/train/train.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,21 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
133133

134134
try:
135135
dist.init_process_group(
136-
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://", world_size=n_gpus, rank=rank
136+
backend=(
137+
"gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"
138+
),
139+
init_method="env://",
140+
world_size=n_gpus,
141+
rank=rank,
137142
)
138143
except:
139144
dist.init_process_group(
140-
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank
145+
backend=(
146+
"gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"
147+
),
148+
init_method="env://?use_libuv=False",
149+
world_size=n_gpus,
150+
rank=rank,
141151
)
142152
torch.manual_seed(hps.train.seed)
143153
if torch.cuda.is_available():
@@ -243,13 +253,17 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
243253
if hasattr(net_g, "module"):
244254
logger.info(
245255
net_g.module.load_state_dict(
246-
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
256+
torch.load(
257+
hps.pretrainG, map_location="cpu", weights_only=True
258+
)["model"]
247259
)
248260
) ##测试不加载优化器
249261
else:
250262
logger.info(
251263
net_g.load_state_dict(
252-
torch.load(hps.pretrainG, map_location="cpu", weights_only=True)["model"]
264+
torch.load(
265+
hps.pretrainG, map_location="cpu", weights_only=True
266+
)["model"]
253267
)
254268
) ##测试不加载优化器
255269
if hps.pretrainD != "":
@@ -258,13 +272,17 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
258272
if hasattr(net_d, "module"):
259273
logger.info(
260274
net_d.module.load_state_dict(
261-
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
275+
torch.load(
276+
hps.pretrainD, map_location="cpu", weights_only=True
277+
)["model"]
262278
)
263279
)
264280
else:
265281
logger.info(
266282
net_d.load_state_dict(
267-
torch.load(hps.pretrainD, map_location="cpu", weights_only=True)["model"]
283+
torch.load(
284+
hps.pretrainD, map_location="cpu", weights_only=True
285+
)["model"]
268286
)
269287
)
270288

infer/modules/uvr5/mdxnet.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,12 @@ def prediction(self, m, vocal_root, others_root, format):
208208
sources = self.demix(mix.T)
209209
opt = sources[0].T
210210
if format in ["wav", "flac"]:
211-
save_audio("%s/vocal_%s.%s" % (vocal_root, basename, format), mix - opt, rate)
212-
save_audio("%s/instrument_%s.%s" % (others_root, basename, format), opt, rate)
211+
save_audio(
212+
"%s/vocal_%s.%s" % (vocal_root, basename, format), mix - opt, rate
213+
)
214+
save_audio(
215+
"%s/instrument_%s.%s" % (others_root, basename, format), opt, rate
216+
)
213217
else:
214218
path_vocal = "%s/vocal_%s.wav" % (vocal_root, basename)
215219
path_other = "%s/instrument_%s.wav" % (others_root, basename)

infer/modules/uvr5/vr.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ def __init__(self, agg, model_path, device, is_half, tta=False):
4848
self.mp = mp
4949
self.model = model
5050

51-
def _path_audio_(
52-
self, music_file, ins_root=None, vocal_root=None, format="flac"
53-
):
51+
def _path_audio_(self, music_file, ins_root=None, vocal_root=None, format="flac"):
5452
if ins_root is None and vocal_root is None:
5553
return "No save root."
5654
name = os.path.basename(music_file)
@@ -134,10 +132,14 @@ def _path_audio_(
134132
else:
135133
head = "instrument_"
136134
if format in ["wav", "flac"]:
137-
save_audio(os.path.join(
135+
save_audio(
136+
os.path.join(
138137
ins_root,
139138
head + "{}_{}.{}".format(name, self.data["agg"], format),
140-
), wav_instrument, self.mp.param["sr"])
139+
),
140+
wav_instrument,
141+
self.mp.param["sr"],
142+
)
141143
else:
142144
path = os.path.join(
143145
ins_root, head + "{}_{}.wav".format(name, self.data["agg"])
@@ -162,10 +164,14 @@ def _path_audio_(
162164
wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
163165
logger.info("%s vocals done" % name)
164166
if format in ["wav", "flac"]:
165-
save_audio(os.path.join(
167+
save_audio(
168+
os.path.join(
166169
vocal_root,
167170
head + "{}_{}.{}".format(name, self.data["agg"], format),
168-
), wav_vocals, self.mp.param["sr"])
171+
),
172+
wav_vocals,
173+
self.mp.param["sr"],
174+
)
169175
else:
170176
path = os.path.join(
171177
vocal_root, head + "{}_{}.wav".format(name, self.data["agg"])

infer/modules/vc/modules.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,16 +252,24 @@ def vc_multi(
252252
try:
253253
tgt_sr, audio_opt = opt
254254
if format1 in ["wav", "flac"]:
255-
save_audio("%s/%s.%s"
256-
% (opt_root, os.path.basename(path), format1), audio_opt, tgt_sr)
255+
save_audio(
256+
"%s/%s.%s"
257+
% (opt_root, os.path.basename(path), format1),
258+
audio_opt,
259+
tgt_sr,
260+
)
257261
else:
258262
path = "%s/%s.%s" % (
259263
opt_root,
260264
os.path.basename(path),
261265
format1,
262266
)
263267
with open(path, "wb") as outf:
264-
wav2(float_np_array_to_wav_buf(audio_opt, tgt_sr), outf, format1)
268+
wav2(
269+
float_np_array_to_wav_buf(audio_opt, tgt_sr),
270+
outf,
271+
format1,
272+
)
265273
except:
266274
info += traceback.format_exc()
267275
infos.append("%s->%s" % (os.path.basename(path), info))

rvc/layers/generators.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,11 @@ def _f02sine(self, f0: torch.Tensor, upp: int):
166166
rad = f0 / self.sampling_rate * a
167167
rad2 = torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5
168168
rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0)
169-
rad += F.pad(rad_acc, (0, 0, 1, 0), mode='constant')
169+
rad += F.pad(rad_acc, (0, 0, 1, 0), mode="constant")
170170
rad = rad.reshape(f0.shape[0], -1, 1)
171-
b = torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1)
171+
b = torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(
172+
1, 1, -1
173+
)
172174
rad *= b
173175
rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
174176
rand_ini[..., 0] = 0

tools/cmd/onnx/infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020

2121
audio = model.infer(wav, sr, sampling_rate, sid, f0_method, f0_up_key)
2222

23-
save_audio(out_path, audio, sampling_rate)
23+
save_audio(out_path, audio, sampling_rate)

web.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,19 +144,22 @@ def forward_dml(ctx, x, scale):
144144
names = [""]
145145
index_paths = [""]
146146

147+
147148
def lookup_names(weight_root):
148149
global names
149150
for name in os.listdir(weight_root):
150151
if name.endswith(".pth"):
151152
names.append(name)
152153

154+
153155
def lookup_indices(index_root):
154156
global index_paths
155157
for root, _, files in os.walk(index_root, topdown=False):
156158
for name in files:
157159
if name.endswith(".index") and "trained" not in name:
158160
index_paths.append(str(pathlib.Path(root, name)))
159161

162+
160163
lookup_names(weight_root)
161164
lookup_indices(index_root)
162165
lookup_indices(outside_index_root)

0 commit comments

Comments
 (0)