Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions record_and_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def record_and_predict():
# Segment assembly state
segment = [] # list of float32 chunks (includes pre, speech, trailing silence)
speech_active = False
turn_started = False
trailing_silence = 0
since_trigger_chunks = 0

Expand All @@ -130,15 +131,23 @@ def record_and_predict():
is_speech = vad.prob(f32) > VAD_THRESHOLD

if not speech_active:
# Warmup pre-speech buffer until trigger
pre_buffer.append(f32)
# Not inside an active VAD span. This can mean one of two cases:
# 1. Earlier VAD spans exist, but the turn has not reached its end yet
# 2. No VAD span has been identified yet
if not turn_started and not is_speech:
# Warmup pre-speech buffer until trigger
pre_buffer.append(f32)
if is_speech:
# Trigger: start a new segment with pre-speech
segment = list(pre_buffer)
if not turn_started:
# Trigger: start a new segment with pre-speech
segment = list(pre_buffer)
since_trigger_chunks = 0
segment.append(f32)
since_trigger_chunks += 1
speech_active = True
trailing_silence = 0
since_trigger_chunks = 1
turn_started = True
print(">> VAD span started")
else:
# Already in a segment
segment.append(f32)
Expand All @@ -152,13 +161,16 @@ def record_and_predict():
if trailing_silence >= stop_chunks or since_trigger_chunks >= max_chunks:
# Pause capture while we process
stream.stop_stream()
_process_segment(np.concatenate(segment, dtype=np.float32))
# Reset for next segment
segment.clear()
print(">> VAD span ended")
if _process_segment(np.concatenate(segment, dtype=np.float32)):
# Reset for next segment
segment.clear()
trailing_silence = 0
since_trigger_chunks = 0
pre_buffer.clear()
turn_started = False
print("[WARN] Turn ended")
speech_active = False
trailing_silence = 0
since_trigger_chunks = 0
pre_buffer.clear()
stream.start_stream()
print("Listening for speech...")

Expand All @@ -173,7 +185,7 @@ def record_and_predict():
def _process_segment(segment_audio_f32: np.ndarray):
if segment_audio_f32.size == 0:
print("Captured empty audio segment, skipping prediction.")
return
return 0

if DEBUG_SAVE_WAV:
wavfile.write(TEMP_OUTPUT_WAV, RATE, (segment_audio_f32 * 32767.0).astype(np.int16))
Expand All @@ -192,6 +204,7 @@ def _process_segment(segment_audio_f32: np.ndarray):
print(f"Prediction: {'Complete' if pred == 1 else 'Incomplete'}")
print(f"Probability of complete: {prob:.4f}")
print(f"Inference time: {dt_ms:.2f} ms")
return pred


if __name__ == "__main__":
Expand Down