|
1 | 1 | from io import BufferedWriter, BytesIO |
2 | 2 | from pathlib import Path |
3 | | -from typing import Dict, Tuple, Optional, Union |
| 3 | +from typing import Dict, Tuple, Optional, Union, List |
4 | 4 | import os |
5 | 5 | import math |
6 | 6 | import wave |
|
9 | 9 | from numba import jit |
10 | 10 | import av |
11 | 11 | from av.audio.resampler import AudioResampler |
| 12 | +from av.audio.frame import AudioFrame |
12 | 13 |
|
13 | 14 | video_format_dict: Dict[str, str] = { |
14 | 15 | "m4a": "mp4", |
@@ -66,35 +67,47 @@ def load_audio(file: Union[str, BytesIO, Path], sr: Optional[int]=None, format: |
66 | 67 | if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()): |
67 | 68 | raise FileNotFoundError(f"File not found: {file}") |
68 | 69 | rate = 0 |
69 | | - try: |
70 | | - container = av.open(file, format=format) |
71 | | - resampler = AudioResampler(format="fltp", layout="mono", rate=sr) |
72 | 70 |
|
73 | | - # Estimated maximum total number of samples to pre-allocate the array |
74 | | - # AV stores length in microseconds by default |
75 | | - estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000 |
76 | | - decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32) |
| 71 | + container = av.open(file, format=format) |
| 72 | + audio_stream = next(s for s in container.streams if s.type == "audio") |
| 73 | + channels = 1 if audio_stream.layout == "mono" else 2 |
| 74 | + container.seek(0) |
| 75 | + resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr) |
77 | 76 |
|
78 | | - offset = 0 |
79 | | - for frame in container.decode(audio=0): |
80 | | - frame.pts = None # Clear presentation timestamp to avoid resampling issues |
| 77 | + # Estimated maximum total number of samples to pre-allocate the array |
| 78 | + # AV stores length in microseconds by default |
| 79 | + estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000 |
| 80 | + decoded_audio = np.zeros(estimated_total_samples + 1 if channels == 1 else (channels, estimated_total_samples + 1), dtype=np.float32) |
| 81 | + |
| 82 | + offset = 0 |
| 83 | + |
| 84 | + def process_packet(packet: List[AudioFrame]): |
| 85 | + frames_data = [] |
| 86 | + for frame in packet: |
| 87 | + frame.pts = None # Clear presentation timestamp to avoid resampling issues |
81 | 88 | resampled_frames = resampler.resample(frame) |
82 | 89 | for resampled_frame in resampled_frames: |
83 | | - rate = resampled_frame.rate |
84 | | - frame_data = resampled_frame.to_ndarray()[0] |
85 | | - end_index = offset + len(frame_data) |
| 90 | + frame_data = resampled_frame.to_ndarray() |
| 91 | + frames_data.append(frame_data) |
| 92 | + return frames_data |
86 | 93 |
|
87 | | - # Check if decoded_audio has enough space, and resize if necessary |
88 | | - if end_index > decoded_audio.shape[0]: |
89 | | - decoded_audio = np.resize(decoded_audio, end_index + 1) |
| 94 | + def frame_iter(container): |
| 95 | + for p in container.demux(container.streams.audio[0]): |
| 96 | + yield p.decode() |
90 | 97 |
|
91 | | - decoded_audio[offset:end_index] = frame_data |
92 | | - offset += len(frame_data) |
| 98 | + for frames_data in map(process_packet, frame_iter(container)): |
| 99 | + for frame_data in frames_data: |
| 100 | + end_index = offset + len(frame_data[0]) |
93 | 101 |
|
94 | | - # Truncate the array to the actual size |
95 | | - decoded_audio = decoded_audio[:offset] |
96 | | - except Exception as e: |
97 | | - raise RuntimeError(f"Failed to load audio: {e}") |
| 102 | + # Check if decoded_audio has enough space, and resize if necessary |
| 103 | + if end_index > decoded_audio.shape[1]: |
| 104 | + decoded_audio = np.resize(decoded_audio, (decoded_audio.shape[0], end_index*4)) |
| 105 | + |
| 106 | + np.copyto(decoded_audio[..., offset:end_index], frame_data) |
| 107 | + offset += len(frame_data[0]) |
| 108 | + |
| 109 | + # Truncate the array to the actual size |
| 110 | + decoded_audio = decoded_audio[..., :offset] |
98 | 111 |
|
99 | 112 | if sr is not None: |
100 | 113 | return decoded_audio |
@@ -173,7 +186,7 @@ def resample_audio( |
173 | 186 | print(f"Failed to remove the original file: {e}") |
174 | 187 |
|
175 | 188 |
|
176 | | -def get_audio_properties(input_path: str) -> Tuple: |
| 189 | +def get_audio_properties(input_path: str) -> Tuple[int, int]: |
177 | 190 | container = av.open(input_path) |
178 | 191 | audio_stream = next(s for s in container.streams if s.type == "audio") |
179 | 192 | channels = 1 if audio_stream.layout == "mono" else 2 |
|
0 commit comments