diff --git a/record_and_predict.py b/record_and_predict.py index f99e676..d5ba55f 100644 --- a/record_and_predict.py +++ b/record_and_predict.py @@ -104,6 +104,7 @@ def record_and_predict(): # Segment assembly state segment = [] # list of float32 chunks (includes pre, speech, trailing silence) speech_active = False + turn_started = False trailing_silence = 0 since_trigger_chunks = 0 @@ -130,15 +131,23 @@ def record_and_predict(): is_speech = vad.prob(f32) > VAD_THRESHOLD if not speech_active: - # Warmup pre-speech buffer until trigger - pre_buffer.append(f32) + # It could mean two cases: + # 1. There are already some VAD spans, but it haven't reach a turn end + # 2. No VAD span has been identified + if not turn_started and not is_speech: + # Warmup pre-speech buffer until trigger + pre_buffer.append(f32) if is_speech: - # Trigger: start a new segment with pre-speech - segment = list(pre_buffer) + if not turn_started: + # Trigger: start a new segment with pre-speech + segment = list(pre_buffer) + since_trigger_chunks = 0 segment.append(f32) + since_trigger_chunks += 1 speech_active = True trailing_silence = 0 - since_trigger_chunks = 1 + turn_started = True + print(">> VAD span started") else: # Already in a segment segment.append(f32) @@ -152,13 +161,16 @@ def record_and_predict(): if trailing_silence >= stop_chunks or since_trigger_chunks >= max_chunks: # Pause capture while we process stream.stop_stream() - _process_segment(np.concatenate(segment, dtype=np.float32)) - # Reset for next segment - segment.clear() + print(">> VAD span ended") + if _process_segment(np.concatenate(segment, dtype=np.float32)): + # Reset for next segment + segment.clear() + trailing_silence = 0 + since_trigger_chunks = 0 + pre_buffer.clear() + turn_started = False + print("[WARN] Turn ended") speech_active = False - trailing_silence = 0 - since_trigger_chunks = 0 - pre_buffer.clear() stream.start_stream() print("Listening for speech...") @@ -173,7 +185,7 @@ def record_and_predict(): def _process_segment(segment_audio_f32: np.ndarray): if segment_audio_f32.size == 0: print("Captured empty audio segment, skipping prediction.") - return + return 0 if DEBUG_SAVE_WAV: wavfile.write(TEMP_OUTPUT_WAV, RATE, (segment_audio_f32 * 32767.0).astype(np.int16)) @@ -192,6 +204,7 @@ def _process_segment(segment_audio_f32: np.ndarray): print(f"Prediction: {'Complete' if pred == 1 else 'Incomplete'}") print(f"Probability of complete: {prob:.4f}") print(f"Inference time: {dt_ms:.2f} ms") + return pred if __name__ == "__main__":