Skip to content

Commit cd728e6

Browse files
authored
Support more input transcription parameters for openai realtime (#1637)
1 parent a228f6c commit cd728e6

File tree

4 files changed

+30
-0
lines changed

4 files changed

+30
-0
lines changed

.changeset/fresh-foxes-remember.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
+---
+"livekit-plugins-openai": patch
+---
+
+Support more input transcription parameters for openai realtime

livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class ContentPart(TypedDict):
 
 class InputAudioTranscription(TypedDict):
     model: InputTranscriptionModel | str
+    language: NotRequired[str]
+    prompt: NotRequired[str]
 
 
 class ServerVad(TypedDict):

livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ class ServerVadOptions:
 @dataclass
 class InputTranscriptionOptions:
     model: api_proto.InputTranscriptionModel | str
+    language: str | None = None
+    prompt: str | None = None
 
 
 @dataclass
@@ -976,6 +978,14 @@ def session_update(
         input_audio_transcription_opts = {
             "model": self._opts.input_audio_transcription.model,
         }
+        if self._opts.input_audio_transcription.language is not None:
+            input_audio_transcription_opts["language"] = (
+                self._opts.input_audio_transcription.language
+            )
+        if self._opts.input_audio_transcription.prompt is not None:
+            input_audio_transcription_opts["prompt"] = (
+                self._opts.input_audio_transcription.prompt
+            )
 
         session_data: api_proto.ClientEvent.SessionUpdateData = {
             "modalities": self._opts.modalities,
@@ -1296,6 +1306,8 @@ def _handle_session_updated(
         else:
             input_audio_transcription = InputTranscriptionOptions(
                 model=session["input_audio_transcription"]["model"],
+                language=session["input_audio_transcription"].get("language"),
+                prompt=session["input_audio_transcription"].get("prompt"),
             )
 
         self.emit(

livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class _STTOptions:
     language: str
     detect_language: bool
     model: WhisperModels | str
+    prompt: str | None = None
 
 
 class STT(stt.STT):
@@ -48,6 +49,7 @@ def __init__(
         language: str = "en",
         detect_language: bool = False,
         model: WhisperModels | str = "whisper-1",
+        prompt: str | None = None,
         base_url: str | None = None,
         api_key: str | None = None,
         client: openai.AsyncClient | None = None,
@@ -69,6 +71,7 @@ def __init__(
             language=language,
             detect_language=detect_language,
             model=model,
+            prompt=prompt,
         )
 
         self._client = client or openai.AsyncClient(
@@ -91,9 +94,11 @@ def update_options(
         *,
         model: WhisperModels | GroqAudioModels | None = None,
         language: str | None = None,
+        prompt: str | None = None,
     ) -> None:
         self._opts.model = model or self._opts.model
         self._opts.language = language or self._opts.language
+        self._opts.prompt = prompt or self._opts.prompt
 
     @staticmethod
     def with_groq(
@@ -103,6 +108,7 @@ def with_groq(
         base_url: str | None = "https://api.groq.com/openai/v1",
         client: openai.AsyncClient | None = None,
         language: str = "en",
+        prompt: str | None = None,
         detect_language: bool = False,
     ) -> STT:
         """
@@ -123,6 +129,7 @@ def with_groq(
             client=client,
             language=language,
             detect_language=detect_language,
+            prompt=prompt,
         )
 
     def _sanitize_options(self, *, language: str | None = None) -> _STTOptions:
@@ -140,6 +147,9 @@ async def _recognize_impl(
         try:
             config = self._sanitize_options(language=language)
             data = rtc.combine_audio_frames(buffer).to_wav_bytes()
+            prompt = (
+                self._opts.prompt if self._opts.prompt is not None else openai.NOT_GIVEN
+            )
             resp = await self._client.audio.transcriptions.create(
                 file=(
                     "file.wav",
@@ -148,6 +158,7 @@ async def _recognize_impl(
                 ),
                 model=self._opts.model,
                 language=config.language,
+                prompt=prompt,
                 # verbose_json returns language and other details
                 response_format="verbose_json",
                 timeout=httpx.Timeout(30, connect=conn_options.timeout),

0 commit comments

Comments
 (0)