diff --git a/api/run_eval.py b/api/run_eval.py index 92ded64..b9a2457 100644 --- a/api/run_eval.py +++ b/api/run_eval.py @@ -198,23 +198,38 @@ def transcribe_with_retry( elif model_name.startswith("elevenlabs/"): client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY")) + + api_params = { + "model_id": model_name.split("/")[1], + "language_code": "en", + "diarize": True, + "timestamps_granularity": "word", + } + if use_url: response = requests.get(sample["row"]["audio"][0]["src"]) audio_data = BytesIO(response.content) - transcription = client.speech_to_text.convert( - file=audio_data, - model_id=model_name.split("/")[1], - language_code="eng", - tag_audio_events=True, - ) + transcription = client.speech_to_text.convert(file=audio_data, **api_params) else: with open(audio_file_path, "rb") as audio_file: - transcription = client.speech_to_text.convert( - file=audio_file, - model_id=model_name.split("/")[1], - language_code="eng", - tag_audio_events=True, - ) + transcription = client.speech_to_text.convert(file=audio_file, **api_params) + + if hasattr(transcription, 'words') and transcription.words: + speaker_word_counts = {} + speaker_words = {} + + for word_obj in transcription.words: + if hasattr(word_obj, 'speaker_id') and word_obj.speaker_id: + speaker_id = word_obj.speaker_id + speaker_word_counts[speaker_id] = speaker_word_counts.get(speaker_id, 0) + 1 + if speaker_id not in speaker_words: + speaker_words[speaker_id] = [] + speaker_words[speaker_id].append(word_obj.text) + + if speaker_word_counts: + dominant_speaker = max(speaker_word_counts, key=speaker_word_counts.get) + return " ".join(speaker_words[dominant_speaker]) + return transcription.text elif model_name.startswith("revai/"):