
Commit a8783c6

optimize: some training optimizations (#95)

* optimize(train&uvr5): remove soundfile & simplify AudioPre
* fix(audio): too many mallocs
* feat(audio): load_audio supports stereo
* fix(audio): float32 wav saving
* fix(train): missing ckpt var
1 parent f4644ec commit a8783c6

File tree

19 files changed (+164, -434 lines)

infer/lib/audio.py

Lines changed: 87 additions & 30 deletions
@@ -1,11 +1,16 @@
 from io import BufferedWriter, BytesIO
 from pathlib import Path
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Optional, Union, List
 import os
+import math
+import wave

 import numpy as np
+from numba import jit
 import av
 from av.audio.resampler import AudioResampler
+from av.audio.frame import AudioFrame
+import scipy.io.wavfile as wavfile

 video_format_dict: Dict[str, str] = {
     "m4a": "mp4",
@@ -17,6 +22,29 @@
 }


+@jit(nopython=True)
+def float_to_int16(audio: np.ndarray) -> np.ndarray:
+    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
+    am = 32767 * 32768 // am
+    return np.multiply(audio, am).astype(np.int16)
+
+def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32=False) -> BytesIO:
+    buf = BytesIO()
+    if f32:
+        wavfile.write(buf, sr, wav.astype(np.float32))
+    else:
+        with wave.open(buf, "wb") as wf:
+            wf.setnchannels(2 if len(wav.shape) > 1 else 1)
+            wf.setsampwidth(2)  # Sample width in bytes
+            wf.setframerate(sr)  # Sample rate in Hz
+            wf.writeframes(float_to_int16(wav.T if len(wav.shape) > 1 else wav))
+    buf.seek(0, 0)
+    return buf
+
+def save_audio(path: str, audio: np.ndarray, sr: int, f32=False):
+    with open(path, "wb") as f:
+        f.write(float_np_array_to_wav_buf(audio, sr, f32).getbuffer())
+
 def wav2(i: BytesIO, o: BufferedWriter, format: str):
     inp = av.open(i, "r")
     format = video_format_dict.get(format, format)
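The three helpers above replace the former soundfile-based writes: float_to_int16 scales a float signal to int16 based on its peak under numba's nopython JIT, float_np_array_to_wav_buf serializes to an in-memory WAV (16-bit PCM via the wave module, or float32 via scipy.io.wavfile), and save_audio dumps that buffer to disk. A minimal usage sketch, assuming this commit's infer/lib/audio.py is on the import path and using a made-up test tone:

    import numpy as np
    from infer.lib.audio import save_audio, float_np_array_to_wav_buf

    sr = 16000
    t = np.linspace(0.0, 1.0, sr, endpoint=False)
    tone = 0.5 * np.sin(2 * np.pi * 440.0 * t)        # mono float signal in [-1, 1]
    stereo = np.stack([tone, tone])                   # (channels, samples); transposed internally

    save_audio("tone_mono.wav", tone, sr)             # 16-bit PCM path (wave + float_to_int16)
    save_audio("tone_stereo.wav", stereo, sr)         # stereo 16-bit PCM
    save_audio("tone_f32.wav", tone, sr, f32=True)    # float32 WAV via scipy.io.wavfile

    buf = float_np_array_to_wav_buf(tone, sr)         # in-memory WAV, e.g. to feed into wav2()
    print(len(buf.getbuffer()), "bytes")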
@@ -36,43 +64,72 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
     inp.close()


-def load_audio(file: str, sr: int) -> np.ndarray:
-    if not Path(file).exists():
+def load_audio(
+    file: Union[str, BytesIO, Path],
+    sr: Optional[int]=None,
+    format: Optional[str]=None,
+    mono=True
+) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
+    if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()):
         raise FileNotFoundError(f"File not found: {file}")
+    rate = 0
+
+    container = av.open(file, format=format)
+    audio_stream = next(s for s in container.streams if s.type == "audio")
+    channels = 1 if audio_stream.layout == "mono" else 2
+    container.seek(0)
+    resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr) if sr is not None else None
+
+    # Estimated maximum total number of samples to pre-allocate the array
+    # AV stores length in microseconds by default
+    estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000
+    decoded_audio = np.zeros(estimated_total_samples + 1 if channels == 1 else (channels, estimated_total_samples + 1), dtype=np.float32)
+
+    offset = 0
+
+    def process_packet(packet: List[AudioFrame]):
+        frames_data = []
+        rate = 0
+        for frame in packet:
+            frame.pts = None  # Clear presentation timestamp to avoid resampling issues
+            resampled_frames = resampler.resample(frame) if resampler is not None else [frame]
+            for resampled_frame in resampled_frames:
+                frame_data = resampled_frame.to_ndarray()
+                rate = resampled_frame.rate
+                frames_data.append(frame_data)
+        return (rate, frames_data)

-    try:
-        container = av.open(file)
-        resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
+    def frame_iter(container):
+        for p in container.demux(container.streams.audio[0]):
+            yield p.decode()

-        # Estimated maximum total number of samples to pre-allocate the array
-        # AV stores length in microseconds by default
-        estimated_total_samples = int(container.duration * sr // 1_000_000)
-        decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
+    for r, frames_data in map(process_packet, frame_iter(container)):
+        if not rate: rate = r
+        for frame_data in frames_data:
+            end_index = offset + len(frame_data[0])

-        offset = 0
-        for frame in container.decode(audio=0):
-            frame.pts = None  # Clear presentation timestamp to avoid resampling issues
-            resampled_frames = resampler.resample(frame)
-            for resampled_frame in resampled_frames:
-                frame_data = resampled_frame.to_ndarray()[0]
-                end_index = offset + len(frame_data)
+            # Check if decoded_audio has enough space, and resize if necessary
+            if end_index > decoded_audio.shape[1]:
+                decoded_audio = np.resize(decoded_audio, (decoded_audio.shape[0], end_index*4))

-                # Check if decoded_audio has enough space, and resize if necessary
-                if end_index > decoded_audio.shape[0]:
-                    decoded_audio = np.resize(decoded_audio, end_index + 1)
+            np.copyto(decoded_audio[..., offset:end_index], frame_data)
+            offset += len(frame_data[0])

-                decoded_audio[offset:end_index] = frame_data
-                offset += len(frame_data)
+    # Truncate the array to the actual size
+    decoded_audio = decoded_audio[..., :offset]

-        # Truncate the array to the actual size
-        decoded_audio = decoded_audio[:offset]
-    except Exception as e:
-        raise RuntimeError(f"Failed to load audio: {e}")
+    if mono and decoded_audio.shape[0] > 1:
+        decoded_audio = decoded_audio.mean(0)

-    return decoded_audio
+    if sr is not None:
+        return decoded_audio
+    return decoded_audio, rate


-def downsample_audio(input_path: str, output_path: str, format: str) -> None:
+def downsample_audio(input_path: str, output_path: str, format: str, br=128_000) -> None:
+    """
+    default to 128kb/s (equivalent to -q:a 2)
+    """
     if not os.path.exists(input_path):
         return
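load_audio now accepts a path, a pathlib.Path, or an in-memory BytesIO, only resamples when a target rate is given, and can keep stereo instead of forcing mono; when sr is omitted it returns the native sample rate alongside the samples. A minimal sketch of the new call patterns (file names are placeholders, not taken from the repo):

    from io import BytesIO
    from infer.lib.audio import load_audio

    wav16k = load_audio("speech.wav", 16000)                  # resample + mono mixdown (old behaviour)
    stereo16k = load_audio("speech.wav", 16000, mono=False)   # keep both channels of a stereo source
    audio, sr = load_audio("speech.wav")                      # no resampling; native rate comes back too

    with open("speech.wav", "rb") as f:                       # decode straight from memory, with a format hint
        audio, sr = load_audio(BytesIO(f.read()), format="wav")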

@@ -83,7 +140,7 @@ def downsample_audio(input_path: str, output_path: str, format: str) -> None:
     input_stream = input_container.streams.audio[0]
     output_stream = output_container.add_stream(format)

-    output_stream.bit_rate = 128_000  # 128kb/s (equivalent to -q:a 2)
+    output_stream.bit_rate = br

     # Copy packets from the input file to the output file
     for packet in input_container.demux(input_stream):
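The hard-coded bitrate becomes a br keyword, so callers can trade size against quality without editing the function. A hedged usage sketch (file names and the aac codec/container choice are illustrative, not taken from the repo):

    from infer.lib.audio import downsample_audio

    downsample_audio("vocal.wav", "vocal.m4a", "aac")                   # keeps the 128 kb/s default
    downsample_audio("vocal.wav", "vocal_small.m4a", "aac", br=64_000)  # smaller output, lower bitrate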
@@ -141,7 +198,7 @@ def resample_audio(
         print(f"Failed to remove the original file: {e}")


-def get_audio_properties(input_path: str) -> Tuple:
+def get_audio_properties(input_path: str) -> Tuple[int, int]:
     container = av.open(input_path)
     audio_stream = next(s for s in container.streams if s.type == "audio")
     channels = 1 if audio_stream.layout == "mono" else 2
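The bare Tuple annotation is tightened to Tuple[int, int]. Judging from the visible body the two ints are the channel count and the sample rate, but that ordering is an assumption since the return statement falls outside this hunk:

    from infer.lib.audio import get_audio_properties

    channels, rate = get_audio_properties("speech.wav")  # assumed (channels, rate) ordering
    print(channels, rate)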

infer/lib/slicer2.py

Lines changed: 7 additions & 12 deletions
@@ -183,8 +183,7 @@ def main():
     import os.path
     from argparse import ArgumentParser

-    import librosa
-    import soundfile
+    from .audio import load_audio, save_audio

     parser = ArgumentParser()
     parser.add_argument("audio", type=str, help="The audio to be sliced")
@@ -230,7 +229,7 @@ def main():
     out = args.out
     if out is None:
         out = os.path.dirname(os.path.abspath(args.audio))
-    audio, sr = librosa.load(args.audio, sr=None, mono=False)
+    audio, sr = load_audio(args.audio, mono=False)
     slicer = Slicer(
         sr=sr,
         threshold=args.db_thresh,
@@ -245,15 +244,11 @@ def main():
     for i, chunk in enumerate(chunks):
         if len(chunk.shape) > 1:
             chunk = chunk.T
-        soundfile.write(
-            os.path.join(
-                out,
-                f"%s_%d.wav"
-                % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
-            ),
-            chunk,
-            sr,
-        )
+        save_audio(os.path.join(
+            out,
+            f"%s_%d.wav"
+            % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
+        ), chunk, sr)


 if __name__ == "__main__":
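The slicer CLI now reads through load_audio (librosa dropped) and writes each chunk through save_audio (soundfile dropped) while keeping the same "<basename>_<i>.wav" naming; because the import is now package-relative, the script presumably has to be run as a module (python -m infer.lib.slicer2 ...) rather than as a bare file. A standalone sketch of the per-chunk write, using a placeholder silent chunk:

    import numpy as np
    from infer.lib.audio import save_audio

    sr = 44100
    chunk = np.zeros(sr, dtype=np.float32)       # stand-in for one chunk returned by Slicer.slice()
    save_audio(f"example_{0}.wav", chunk, sr)    # 16-bit PCM, same naming idea as main()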

infer/lib/train/utils.py

Lines changed: 3 additions & 128 deletions
@@ -16,62 +16,12 @@
 logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 logger = logging

-"""
-def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
-    assert os.path.isfile(checkpoint_path)
-    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
-
-    ##################
-    def go(model, bkey):
-        saved_state_dict = checkpoint_dict[bkey]
-        if hasattr(model, "module"):
-            state_dict = model.module.state_dict()
-        else:
-            state_dict = model.state_dict()
-        new_state_dict = {}
-        for k, v in state_dict.items():  # the shapes the model expects
-            try:
-                new_state_dict[k] = saved_state_dict[k]
-                if saved_state_dict[k].shape != state_dict[k].shape:
-                    logger.warning(
-                        "shape-%s-mismatch. need: %s, get: %s",
-                        k,
-                        state_dict[k].shape,
-                        saved_state_dict[k].shape,
-                    )  #
-                    raise KeyError
-            except:
-                # logger.info(traceback.format_exc())
-                logger.info("%s is not in the checkpoint", k)  # missing from the pretrained checkpoint
-                new_state_dict[k] = v  # keep the model's own randomly initialized value
-        if hasattr(model, "module"):
-            model.module.load_state_dict(new_state_dict, strict=False)
-        else:
-            model.load_state_dict(new_state_dict, strict=False)
-        return model
-
-    go(combd, "combd")
-    model = go(sbd, "sbd")
-    #############
-    logger.info("Loaded model weights")
-
-    iteration = checkpoint_dict["iteration"]
-    learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None and load_opt == 1
-    ):  ### if loading fails (e.g. the state is empty), reinitialize; this may also affect the lr scheduler update, so it is caught at the outermost level of the train script
-        # try:
-        optimizer.load_state_dict(checkpoint_dict["optimizer"])
-        # except:
-        # traceback.print_exc()
-    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
-    return model, optimizer, learning_rate, iteration
-"""
-

 def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
     assert os.path.isfile(checkpoint_path)
-    saved_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model"]
+    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    saved_state_dict = checkpoint_dict["model"]
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
@@ -132,34 +82,6 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
     )


-"""
-def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at epoch {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
-    if hasattr(combd, "module"):
-        state_dict_combd = combd.module.state_dict()
-    else:
-        state_dict_combd = combd.state_dict()
-    if hasattr(sbd, "module"):
-        state_dict_sbd = sbd.module.state_dict()
-    else:
-        state_dict_sbd = sbd.state_dict()
-    torch.save(
-        {
-            "combd": state_dict_combd,
-            "sbd": state_dict_sbd,
-            "iteration": iteration,
-            "optimizer": optimizer.state_dict(),
-            "learning_rate": learning_rate,
-        },
-        checkpoint_path,
-    )
-"""
-
-
 def summarize(
     writer,
     global_step,
@@ -366,53 +288,6 @@ def get_hparams(init=True):
     return hparams


-"""
-def get_hparams_from_dir(model_dir):
-    config_save_path = os.path.join(model_dir, "config.json")
-    with open(config_save_path, "r") as f:
-        data = f.read()
-    config = json.loads(data)
-
-    hparams = HParams(**config)
-    hparams.model_dir = model_dir
-    return hparams
-
-
-def get_hparams_from_file(config_path):
-    with open(config_path, "r") as f:
-        data = f.read()
-    config = json.loads(data)
-
-    hparams = HParams(**config)
-    return hparams
-
-
-def check_git_hash(model_dir):
-    source_dir = os.path.dirname(os.path.realpath(__file__))
-    if not os.path.exists(os.path.join(source_dir, ".git")):
-        logger.warning(
-            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                source_dir
-            )
-        )
-        return
-
-    cur_hash = subprocess.getoutput("git rev-parse HEAD")
-
-    path = os.path.join(model_dir, "githash")
-    if os.path.exists(path):
-        saved_hash = open(path).read()
-        if saved_hash != cur_hash:
-            logger.warning(
-                "git hash values are different. {}(saved) != {}(current)".format(
-                    saved_hash[:8], cur_hash[:8]
-                )
-            )
-    else:
-        open(path, "w").write(cur_hash)
-"""
-
-
 def get_logger(model_dir, filename="train.log"):
     global logger
     logger = logging.getLogger(os.path.basename(model_dir))
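Besides dropping the long-commented-out helpers, load_checkpoint now keeps the full torch.load result as checkpoint_dict instead of immediately reducing it to its "model" entry; per the commit message ("fix(train): missing ckpt var"), the rest of the function presumably still needs the other fields. A hedged sketch of the resulting access pattern (the file name is a placeholder; the key names mirror those used elsewhere in this file):

    import torch

    checkpoint_dict = torch.load("G_latest.pth", map_location="cpu", weights_only=True)
    saved_state_dict = checkpoint_dict["model"]      # weights merged into the live model
    print(checkpoint_dict["iteration"], checkpoint_dict["learning_rate"])  # still available afterwards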

infer/modules/train/extract_feature_print.py

Lines changed: 6 additions & 4 deletions
@@ -2,6 +2,11 @@
 import sys
 import traceback

+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
+from infer.lib.audio import load_audio
+
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

@@ -20,7 +25,6 @@
 is_half = sys.argv[7].lower() == "true"
 import fairseq
 import numpy as np
-import soundfile as sf
 import torch
 import torch.nn.functional as F

@@ -64,11 +68,9 @@ def printt(strr):
 
 # wave must be 16k, hop_size=320
 def readwave(wav_path, normalize=False):
-    wav, sr = sf.read(wav_path)
+    wav, sr = load_audio(wav_path)
     assert sr == 16000
     feats = torch.from_numpy(wav).float()
-    if feats.dim() == 2:  # double channels
-        feats = feats.mean(-1)
     assert feats.dim() == 1, feats.dim()
     if normalize:
         with torch.no_grad():
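readwave now decodes through the project's own load_audio: with no target rate it returns the samples plus the file's native rate (asserted to be 16 kHz here), and the default mono mixdown makes the old two-channel averaging unnecessary. A standalone sketch of the equivalent call (the path is a hypothetical 16 kHz wav produced by preprocessing):

    import torch
    from infer.lib.audio import load_audio

    wav, sr = load_audio("some_16k_clip.wav")   # sr omitted -> (samples, native_rate), mono by default
    assert sr == 16000
    feats = torch.from_numpy(wav).float()
    assert feats.dim() == 1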
