|
1 | 1 | import time |
| 2 | +from enum import Enum |
2 | 3 | from os import getenv |
3 | 4 |
|
4 | 5 | import azure.cognitiveservices.speech as speechsdk |
|
7 | 8 |
|
8 | 9 | load_dotenv() |
9 | 10 |
|
| 11 | + |
| 12 | +class ServiceType(Enum): |
| 13 | + Azure = "Azure" |
| 14 | + Local = "Local" |
| 15 | + |
| 16 | + |
| 17 | +class TranscriptionStatus(Enum): |
| 18 | + NotStarted = "Not started" |
| 19 | + InProgress = "In progress" |
| 20 | + Completed = "Completed" |
| 21 | + |
| 22 | + |
| 23 | +if "transcription_status" not in st.session_state: |
| 24 | + st.session_state.transcription_status = TranscriptionStatus.NotStarted |
| 25 | + |
| 26 | + |
10 | 27 | with st.sidebar: |
| 28 | + speech_recognition_language = st.selectbox( |
| 29 | + label="Speech recognition language", |
| 30 | + options=[ |
| 31 | + "en-US", |
| 32 | + "ja-JP", |
| 33 | + "zh-CN", |
| 34 | + ], |
| 35 | + ) |
11 | 36 | service_type = st.selectbox( |
12 | 37 | label="Service type", |
13 | 38 | options=[ |
14 | | - "Local", |
15 | | - "Azure", |
| 39 | + ServiceType.Local.value, |
| 40 | + ServiceType.Azure.value, |
16 | 41 | ], |
17 | 42 | ) |
18 | | - if service_type == "Azure": |
19 | | - azure_ai_services_api_key = st.text_input( |
20 | | - label="AZURE_AI_SERVICES_API_KEY", |
21 | | - key="AZURE_AI_SERVICES_API_KEY", |
| 43 | + if service_type == ServiceType.Azure.value: |
| 44 | + azure_ai_speech_api_subscription_key = st.text_input( |
| 45 | + label="AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY", |
| 46 | + value=getenv("AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY"), |
| 47 | + key="AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY", |
22 | 48 | type="password", |
23 | 49 | ) |
24 | 50 | azure_ai_speech_region = st.text_input( |
|
27 | 53 | key="AZURE_AI_SPEECH_REGION", |
28 | 54 | type="default", |
29 | 55 | ) |
30 | | - if service_type == "Local": |
| 56 | + if service_type == ServiceType.Local.value: |
31 | 57 | host = st.text_input( |
32 | 58 | label="Host", |
33 | 59 | value="ws://localhost:5000", |
|
40 | 66 |
|
41 | 67 |
|
42 | 68 | def is_configured(): |
43 | | - if service_type == "Azure": |
44 | | - return azure_ai_services_api_key and azure_ai_speech_region |
45 | | - if service_type == "Local": |
46 | | - return host != "" |
| 69 | + if service_type == ServiceType.Azure.value: |
| 70 | + return azure_ai_speech_api_subscription_key and azure_ai_speech_region and speech_recognition_language |
| 71 | + if service_type == ServiceType.Local.value: |
| 72 | + return host != "" and speech_recognition_language |
47 | 73 | return False |
48 | 74 |
|
49 | 75 |
|
50 | 76 | def get_speech_config(): |
51 | | - if service_type == "Azure": |
| 77 | + if service_type == ServiceType.Azure.value: |
52 | 78 | return speechsdk.SpeechConfig( |
53 | | - subscription=azure_ai_services_api_key, |
| 79 | + subscription=azure_ai_speech_api_subscription_key, |
54 | 80 | region=azure_ai_speech_region, |
| 81 | + speech_recognition_language=speech_recognition_language, |
55 | 82 | ) |
56 | | - if service_type == "Local": |
| 83 | + if service_type == ServiceType.Local.value: |
57 | 84 | return speechsdk.SpeechConfig( |
58 | 85 | endpoint=host, |
| 86 | + speech_recognition_language=speech_recognition_language, |
59 | 87 | ) |
60 | 88 |
|
61 | 89 |
|
@@ -113,10 +141,12 @@ def conversation_transcriber_session_started_cb(evt: speechsdk.SessionEventArgs) |
113 | 141 |
|
114 | 142 | st.title("Azure AI Speech Services") |
115 | 143 |
|
| 144 | +# Show transcription status |
| 145 | +st.info(f"Transcription status: {st.session_state.transcription_status}") |
| 146 | + |
116 | 147 | if not is_configured(): |
117 | 148 | st.warning("Please fill in the required fields at the sidebar.") |
118 | 149 |
|
119 | | -st.info("Transcribe your speech.") |
120 | | - |
121 | 150 | if st.button("Transcribe", disabled=not is_configured()): |
| 151 | + st.session_state.transcription_status = TranscriptionStatus.InProgress |
122 | 152 | from_mic() |
0 commit comments