Skip to content

Commit e0fffe0

Browse files
committed
fix(audio): too many mallocs
1 parent 0bd8bc4 commit e0fffe0

File tree

1 file changed

+37
-24
lines changed

1 file changed

+37
-24
lines changed

infer/lib/audio.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from io import BufferedWriter, BytesIO
22
from pathlib import Path
3-
from typing import Dict, Tuple, Optional, Union
3+
from typing import Dict, Tuple, Optional, Union, List
44
import os
55
import math
66
import wave
@@ -9,6 +9,7 @@
99
from numba import jit
1010
import av
1111
from av.audio.resampler import AudioResampler
12+
from av.audio.frame import AudioFrame
1213

1314
video_format_dict: Dict[str, str] = {
1415
"m4a": "mp4",
@@ -66,35 +67,47 @@ def load_audio(file: Union[str, BytesIO, Path], sr: Optional[int]=None, format:
6667
if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()):
6768
raise FileNotFoundError(f"File not found: {file}")
6869
rate = 0
69-
try:
70-
container = av.open(file, format=format)
71-
resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
7270

73-
# Estimated maximum total number of samples to pre-allocate the array
74-
# AV stores length in microseconds by default
75-
estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000
76-
decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
71+
container = av.open(file, format=format)
72+
audio_stream = next(s for s in container.streams if s.type == "audio")
73+
channels = 1 if audio_stream.layout == "mono" else 2
74+
container.seek(0)
75+
resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr)
7776

78-
offset = 0
79-
for frame in container.decode(audio=0):
80-
frame.pts = None # Clear presentation timestamp to avoid resampling issues
77+
# Estimated maximum total number of samples to pre-allocate the array
78+
# AV stores length in microseconds by default
79+
estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000
80+
decoded_audio = np.zeros(estimated_total_samples + 1 if channels == 1 else (channels, estimated_total_samples + 1), dtype=np.float32)
81+
82+
offset = 0
83+
84+
def process_packet(packet: List[AudioFrame]):
85+
frames_data = []
86+
for frame in packet:
87+
frame.pts = None # 清除时间戳,避免重新采样问题
8188
resampled_frames = resampler.resample(frame)
8289
for resampled_frame in resampled_frames:
83-
rate = resampled_frame.rate
84-
frame_data = resampled_frame.to_ndarray()[0]
85-
end_index = offset + len(frame_data)
90+
frame_data = resampled_frame.to_ndarray()
91+
frames_data.append(frame_data)
92+
return frames_data
8693

87-
# Check if decoded_audio has enough space, and resize if necessary
88-
if end_index > decoded_audio.shape[0]:
89-
decoded_audio = np.resize(decoded_audio, end_index + 1)
94+
def frame_iter(container):
95+
for p in container.demux(container.streams.audio[0]):
96+
yield p.decode()
9097

91-
decoded_audio[offset:end_index] = frame_data
92-
offset += len(frame_data)
98+
for frames_data in map(process_packet, frame_iter(container)):
99+
for frame_data in frames_data:
100+
end_index = offset + len(frame_data[0])
93101

94-
# Truncate the array to the actual size
95-
decoded_audio = decoded_audio[:offset]
96-
except Exception as e:
97-
raise RuntimeError(f"Failed to load audio: {e}")
102+
# 检查 decoded_audio 是否有足够的空间,并在必要时调整大小
103+
if end_index > decoded_audio.shape[1]:
104+
decoded_audio = np.resize(decoded_audio, (decoded_audio.shape[0], end_index*4))
105+
106+
np.copyto(decoded_audio[..., offset:end_index], frame_data)
107+
offset += len(frame_data[0])
108+
109+
# Truncate the array to the actual size
110+
decoded_audio = decoded_audio[..., :offset]
98111

99112
if sr is not None:
100113
return decoded_audio
@@ -173,7 +186,7 @@ def resample_audio(
173186
print(f"Failed to remove the original file: {e}")
174187

175188

176-
def get_audio_properties(input_path: str) -> Tuple:
189+
def get_audio_properties(input_path: str) -> Tuple[int, int]:
177190
container = av.open(input_path)
178191
audio_stream = next(s for s in container.streams if s.type == "audio")
179192
channels = 1 if audio_stream.layout == "mono" else 2

0 commit comments

Comments
 (0)