Skip to content

Commit c59b482

Browse files
authored
fix aws credentials when using env vals (#4403)
1 parent 4fc3319 commit c59b482

File tree

1 file changed

+28
-11
lines changed
  • livekit-plugins/livekit-plugins-aws/livekit/plugins/aws

1 file changed

+28
-11
lines changed

livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
TranscriptResultStream,
4848
)
4949
from smithy_aws_core.identity import (
50+
AWSCredentialsIdentity,
5051
ContainerCredentialsResolver,
5152
EnvironmentCredentialsResolver,
5253
IMDSCredentialsResolver,
@@ -195,9 +196,22 @@ async def _run(self) -> None:
195196
while True:
196197
config_kwargs: dict[str, Any] = {"region": self._opts.region}
197198
if self._credentials:
198-
config_kwargs["aws_access_key_id"] = self._credentials.access_key_id
199-
config_kwargs["aws_secret_access_key"] = self._credentials.secret_access_key
200-
config_kwargs["aws_session_token"] = self._credentials.session_token
199+
# Use a credentials resolver for explicit credentials
200+
# for some reason, Config with direct values doesn't work
201+
class StaticCredsResolver:
202+
def __init__(self, creds: Credentials):
203+
self._identity = AWSCredentialsIdentity(
204+
access_key_id=creds.access_key_id,
205+
secret_access_key=creds.secret_access_key,
206+
session_token=creds.session_token,
207+
)
208+
209+
async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity:
210+
return self._identity
211+
212+
config_kwargs["aws_credentials_identity_resolver"] = StaticCredsResolver(
213+
self._credentials
214+
)
201215
else:
202216
config_kwargs["aws_credentials_identity_resolver"] = ChainedIdentityResolver(
203217
resolvers=(
@@ -229,6 +243,8 @@ async def _run(self) -> None:
229243
}
230244
filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
231245

246+
tasks: list[asyncio.Task[Any]] = []
247+
232248
try:
233249
stream = await client.start_stream_transcription(
234250
input=StartStreamTranscriptionInput(**filtered_config)
@@ -286,14 +302,15 @@ async def handle_transcript_events(
286302
else:
287303
raise e
288304
finally:
289-
# Close input stream first
290-
await utils.aio.gracefully_cancel(tasks[0])
291-
292-
# Wait for output stream to close cleanly
293-
try:
294-
await asyncio.wait_for(tasks[1], timeout=3.0)
295-
except (asyncio.TimeoutError, asyncio.CancelledError):
296-
await utils.aio.gracefully_cancel(tasks[1])
305+
if tasks:
306+
# Close input stream first
307+
await utils.aio.gracefully_cancel(tasks[0])
308+
309+
# Wait for output stream to close cleanly
310+
try:
311+
await asyncio.wait_for(tasks[1], timeout=3.0)
312+
except (asyncio.TimeoutError, asyncio.CancelledError):
313+
await utils.aio.gracefully_cancel(tasks[1])
297314

298315
# Ensure gather future is retrieved to avoid "exception never retrieved"
299316
with contextlib.suppress(Exception):

0 commit comments

Comments
 (0)