117 changes: 87 additions & 30 deletions infer/lib/audio.py
@@ -1,11 +1,16 @@
from io import BufferedWriter, BytesIO
from pathlib import Path
from typing import Dict, Tuple
from typing import Dict, Tuple, Optional, Union, List
import os
import math
import wave

import numpy as np
from numba import jit
import av
from av.audio.resampler import AudioResampler
from av.audio.frame import AudioFrame
import scipy.io.wavfile as wavfile

video_format_dict: Dict[str, str] = {
    "m4a": "mp4",
@@ -17,6 +22,29 @@
}


@jit(nopython=True)
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    # Choose an integer gain of roughly 32767 / ceil(peak) so the scaled signal
    # lands just inside the int16 range (assumes a non-silent input).
    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
    am = 32767 * 32768 // am
    return np.multiply(audio, am).astype(np.int16)

def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32=False) -> BytesIO:
    buf = BytesIO()
    if f32:
        wavfile.write(buf, sr, wav.astype(np.float32))
    else:
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(2 if len(wav.shape) > 1 else 1)
            wf.setsampwidth(2)  # Sample width in bytes
            wf.setframerate(sr)  # Sample rate in Hz
            wf.writeframes(float_to_int16(wav.T if len(wav.shape) > 1 else wav))
    buf.seek(0, 0)
    return buf

def save_audio(path: str, audio: np.ndarray, sr: int, f32=False):
    with open(path, "wb") as f:
        f.write(float_np_array_to_wav_buf(audio, sr, f32).getbuffer())
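
As a quick orientation for reviewers, here is a minimal usage sketch of the new helpers (the sine-tone setup and the infer.lib.audio import path are illustrative assumptions, not part of the diff):

import numpy as np
from infer.lib.audio import float_np_array_to_wav_buf, save_audio

# One second of a 440 Hz tone at 16 kHz, float32 in [-1, 1].
t = np.linspace(0, 1, 16000, endpoint=False, dtype=np.float32)
tone = 0.5 * np.sin(2 * np.pi * 440 * t)

buf = float_np_array_to_wav_buf(tone, 16000)  # in-memory 16-bit PCM WAV
save_audio("tone.wav", tone, 16000)  # same data written to disk
save_audio("tone_f32.wav", tone, 16000, f32=True)  # float32 WAV via scipy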

def wav2(i: BytesIO, o: BufferedWriter, format: str):
    inp = av.open(i, "r")
    format = video_format_dict.get(format, format)
@@ -36,43 +64,72 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
    inp.close()


def load_audio(file: str, sr: int) -> np.ndarray:
    if not Path(file).exists():
def load_audio(
    file: Union[str, BytesIO, Path],
    sr: Optional[int] = None,
    format: Optional[str] = None,
    mono=True,
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
    if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()):
        raise FileNotFoundError(f"File not found: {file}")
    rate = 0

    container = av.open(file, format=format)
    audio_stream = next(s for s in container.streams if s.type == "audio")
    channels = 1 if audio_stream.layout == "mono" else 2
    container.seek(0)
    resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr) if sr is not None else None

    # Estimated maximum total number of samples to pre-allocate the array
    # AV stores length in microseconds by default
    estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000
    decoded_audio = np.zeros(estimated_total_samples + 1 if channels == 1 else (channels, estimated_total_samples + 1), dtype=np.float32)

    offset = 0

    def process_packet(packet: List[AudioFrame]):
        frames_data = []
        rate = 0
        for frame in packet:
            frame.pts = None  # Clear the timestamp to avoid resampling issues
            resampled_frames = resampler.resample(frame) if resampler is not None else [frame]
            for resampled_frame in resampled_frames:
                frame_data = resampled_frame.to_ndarray()
                rate = resampled_frame.rate
                frames_data.append(frame_data)
        return (rate, frames_data)

    try:
        container = av.open(file)
        resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
        def frame_iter(container):
            for p in container.demux(container.streams.audio[0]):
                yield p.decode()

        # Estimated maximum total number of samples to pre-allocate the array
        # AV stores length in microseconds by default
        estimated_total_samples = int(container.duration * sr // 1_000_000)
        decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
        for r, frames_data in map(process_packet, frame_iter(container)):
            if not rate:
                rate = r
            for frame_data in frames_data:
                end_index = offset + len(frame_data[0])

        offset = 0
        for frame in container.decode(audio=0):
            frame.pts = None  # Clear presentation timestamp to avoid resampling issues
            resampled_frames = resampler.resample(frame)
            for resampled_frame in resampled_frames:
                frame_data = resampled_frame.to_ndarray()[0]
                end_index = offset + len(frame_data)
                # Check if decoded_audio has enough space, and resize if necessary
                if end_index > decoded_audio.shape[1]:
                    decoded_audio = np.resize(decoded_audio, (decoded_audio.shape[0], end_index * 4))

                # Check if decoded_audio has enough space, and resize if necessary
                if end_index > decoded_audio.shape[0]:
                    decoded_audio = np.resize(decoded_audio, end_index + 1)
                np.copyto(decoded_audio[..., offset:end_index], frame_data)
                offset += len(frame_data[0])

                decoded_audio[offset:end_index] = frame_data
                offset += len(frame_data)
        # Truncate the array to the actual size
        decoded_audio = decoded_audio[..., :offset]

        # Truncate the array to the actual size
        decoded_audio = decoded_audio[:offset]
    except Exception as e:
        raise RuntimeError(f"Failed to load audio: {e}")

    if mono and len(decoded_audio.shape) > 1:
        decoded_audio = decoded_audio.mean(0)

    return decoded_audio
    if sr is not None:
        return decoded_audio
    return decoded_audio, rate
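
For clarity, the new call patterns look like this (a hedged sketch; file names and the import path are illustrative):

from infer.lib.audio import load_audio

wav = load_audio("input.m4a", sr=16000)  # resampled to 16 kHz, float32, mono by default
wav, native_sr = load_audio("input.m4a")  # no sr given: returns (array, native rate)
stereo, native_sr = load_audio("input.flac", mono=False)  # keeps channels, shape (channels, samples)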


def downsample_audio(input_path: str, output_path: str, format: str) -> None:
def downsample_audio(input_path: str, output_path: str, format: str, br=128_000) -> None:
    """
    Default to 128 kb/s (equivalent to -q:a 2).
    """
    if not os.path.exists(input_path):
        return

@@ -83,7 +140,7 @@ def downsample_audio(input_path: str, output_path: str, format: str) -> None:
    input_stream = input_container.streams.audio[0]
    output_stream = output_container.add_stream(format)

    output_stream.bit_rate = 128_000  # 128 kb/s (equivalent to -q:a 2)
    output_stream.bit_rate = br

    # Copy packets from the input file to the output file
    for packet in input_container.demux(input_stream):
@@ -141,7 +198,7 @@ def resample_audio(
        print(f"Failed to remove the original file: {e}")


def get_audio_properties(input_path: str) -> Tuple:
def get_audio_properties(input_path: str) -> Tuple[int, int]:
    container = av.open(input_path)
    audio_stream = next(s for s in container.streams if s.type == "audio")
    channels = 1 if audio_stream.layout == "mono" else 2
19 changes: 7 additions & 12 deletions infer/lib/slicer2.py
@@ -183,8 +183,7 @@ def main():
    import os.path
    from argparse import ArgumentParser

    import librosa
    import soundfile
    from .audio import load_audio, save_audio

    parser = ArgumentParser()
    parser.add_argument("audio", type=str, help="The audio to be sliced")
@@ -230,7 +229,7 @@ def main():
    out = args.out
    if out is None:
        out = os.path.dirname(os.path.abspath(args.audio))
    audio, sr = librosa.load(args.audio, sr=None, mono=False)
    audio, sr = load_audio(args.audio, mono=False)
    slicer = Slicer(
        sr=sr,
        threshold=args.db_thresh,
@@ -245,15 +244,11 @@
    for i, chunk in enumerate(chunks):
        if len(chunk.shape) > 1:
            chunk = chunk.T
        soundfile.write(
            os.path.join(
                out,
                f"%s_%d.wav"
                % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
            ),
            chunk,
            sr,
        )
        save_audio(os.path.join(
            out,
            "%s_%d.wav"
            % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
        ), chunk, sr)


if __name__ == "__main__":
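Note for reviewers: because main() now pulls load_audio and save_audio via the relative import from .audio, the slicer has to be launched with package context. A hedged invocation sketch (file name illustrative; the --out flag spelling is inferred from args.out):

python -m infer.lib.slicer2 input.wav --out ./sliced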
131 changes: 3 additions & 128 deletions infer/lib/train/utils.py
@@ -16,62 +16,12 @@
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging

"""
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    ##################
    def go(model, bkey):
        saved_state_dict = checkpoint_dict[bkey]
        if hasattr(model, "module"):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():  # the shapes the model expects
            try:
                new_state_dict[k] = saved_state_dict[k]
                if saved_state_dict[k].shape != state_dict[k].shape:
                    logger.warning(
                        "shape-%s-mismatch. need: %s, get: %s",
                        k,
                        state_dict[k].shape,
                        saved_state_dict[k].shape,
                    )  #
                    raise KeyError
            except:
                # logger.info(traceback.format_exc())
                logger.info("%s is not in the checkpoint", k)  # missing from the pretrain
                new_state_dict[k] = v  # keep the model's own random init
        if hasattr(model, "module"):
            model.module.load_state_dict(new_state_dict, strict=False)
        else:
            model.load_state_dict(new_state_dict, strict=False)
        return model

    go(combd, "combd")
    model = go(sbd, "sbd")
    #############
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if (
        optimizer is not None and load_opt == 1
    ):  # if loading fails (e.g. the state is empty), reinitialize; this may also break lr-schedule updates, so catch at the outermost level of the train file
        # try:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
        # except:
        #     traceback.print_exc()
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
"""


def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    assert os.path.isfile(checkpoint_path)
    saved_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model"]
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
@@ -132,34 +82,6 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
    )


"""
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
    logger.info(
        "Saving model and optimizer state at epoch {} to {}".format(
            iteration, checkpoint_path
        )
    )
    if hasattr(combd, "module"):
        state_dict_combd = combd.module.state_dict()
    else:
        state_dict_combd = combd.state_dict()
    if hasattr(sbd, "module"):
        state_dict_sbd = sbd.module.state_dict()
    else:
        state_dict_sbd = sbd.state_dict()
    torch.save(
        {
            "combd": state_dict_combd,
            "sbd": state_dict_sbd,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )
"""


def summarize(
    writer,
    global_step,
@@ -366,53 +288,6 @@ def get_hparams(init=True):
    return hparams


"""
def get_hparams_from_dir(model_dir):
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams


def check_git_hash(model_dir):
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        saved_hash = open(path).read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        open(path, "w").write(cur_hash)
"""


def get_logger(model_dir, filename="train.log"):
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
10 changes: 6 additions & 4 deletions infer/modules/train/extract_feature_print.py
@@ -2,6 +2,11 @@
import sys
import traceback

now_dir = os.getcwd()
sys.path.append(now_dir)

from infer.lib.audio import load_audio

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

@@ -20,7 +25,6 @@
is_half = sys.argv[7].lower() == "true"
import fairseq
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F

@@ -64,11 +68,9 @@ def printt(strr):

# wave must be 16k, hop_size=320
def readwave(wav_path, normalize=False):
    wav, sr = sf.read(wav_path)
    wav, sr = load_audio(wav_path)
    assert sr == 16000
    feats = torch.from_numpy(wav).float()
    if feats.dim() == 2:  # double channels
        feats = feats.mean(-1)
    assert feats.dim() == 1, feats.dim()
    if normalize:
        with torch.no_grad():