|
1 | 1 | import time
|
| 2 | +from enum import Enum |
2 | 3 | from os import getenv
|
3 | 4 |
|
4 | 5 | import azure.cognitiveservices.speech as speechsdk
|
|
7 | 8 |
|
8 | 9 | load_dotenv()
|
9 | 10 |
|
| 11 | + |
| 12 | +class ServiceType(Enum): |
| 13 | + Azure = "Azure" |
| 14 | + Local = "Local" |
| 15 | + |
| 16 | + |
| 17 | +class TranscriptionStatus(Enum): |
| 18 | + NotStarted = "Not started" |
| 19 | + InProgress = "In progress" |
| 20 | + Completed = "Completed" |
| 21 | + |
| 22 | + |
| 23 | +if "transcription_status" not in st.session_state: |
| 24 | + st.session_state.transcription_status = TranscriptionStatus.NotStarted |
| 25 | + |
| 26 | + |
10 | 27 | with st.sidebar:
|
| 28 | + speech_recognition_language = st.selectbox( |
| 29 | + label="Speech recognition language", |
| 30 | + options=[ |
| 31 | + "en-US", |
| 32 | + "ja-JP", |
| 33 | + "zh-CN", |
| 34 | + ], |
| 35 | + ) |
11 | 36 | service_type = st.selectbox(
|
12 | 37 | label="Service type",
|
13 | 38 | options=[
|
14 |
| - "Local", |
15 |
| - "Azure", |
| 39 | + ServiceType.Local.value, |
| 40 | + ServiceType.Azure.value, |
16 | 41 | ],
|
17 | 42 | )
|
18 |
| - if service_type == "Azure": |
19 |
| - azure_ai_services_api_key = st.text_input( |
20 |
| - label="AZURE_AI_SERVICES_API_KEY", |
21 |
| - key="AZURE_AI_SERVICES_API_KEY", |
| 43 | + if service_type == ServiceType.Azure.value: |
| 44 | + azure_ai_speech_api_subscription_key = st.text_input( |
| 45 | + label="AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY", |
| 46 | + value=getenv("AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY"), |
| 47 | + key="AZURE_AI_SPEECH_API_SUBSCRIPTION_KEY", |
22 | 48 | type="password",
|
23 | 49 | )
|
24 | 50 | azure_ai_speech_region = st.text_input(
|
|
27 | 53 | key="AZURE_AI_SPEECH_REGION",
|
28 | 54 | type="default",
|
29 | 55 | )
|
30 |
| - if service_type == "Local": |
| 56 | + if service_type == ServiceType.Local.value: |
31 | 57 | host = st.text_input(
|
32 | 58 | label="Host",
|
33 | 59 | value="ws://localhost:5000",
|
|
40 | 66 |
|
41 | 67 |
|
42 | 68 | def is_configured():
|
43 |
| - if service_type == "Azure": |
44 |
| - return azure_ai_services_api_key and azure_ai_speech_region |
45 |
| - if service_type == "Local": |
46 |
| - return host != "" |
| 69 | + if service_type == ServiceType.Azure.value: |
| 70 | + return azure_ai_speech_api_subscription_key and azure_ai_speech_region and speech_recognition_language |
| 71 | + if service_type == ServiceType.Local.value: |
| 72 | + return host != "" and speech_recognition_language |
47 | 73 | return False
|
48 | 74 |
|
49 | 75 |
|
50 | 76 | def get_speech_config():
|
51 |
| - if service_type == "Azure": |
| 77 | + if service_type == ServiceType.Azure.value: |
52 | 78 | return speechsdk.SpeechConfig(
|
53 |
| - subscription=azure_ai_services_api_key, |
| 79 | + subscription=azure_ai_speech_api_subscription_key, |
54 | 80 | region=azure_ai_speech_region,
|
| 81 | + speech_recognition_language=speech_recognition_language, |
55 | 82 | )
|
56 |
| - if service_type == "Local": |
| 83 | + if service_type == ServiceType.Local.value: |
57 | 84 | return speechsdk.SpeechConfig(
|
58 | 85 | endpoint=host,
|
| 86 | + speech_recognition_language=speech_recognition_language, |
59 | 87 | )
|
60 | 88 |
|
61 | 89 |
|
@@ -113,10 +141,12 @@ def conversation_transcriber_session_started_cb(evt: speechsdk.SessionEventArgs)
|
113 | 141 |
|
114 | 142 | st.title("Azure AI Speech Services")
|
115 | 143 |
|
| 144 | +# Show transcription status |
| 145 | +st.info(f"Transcription status: {st.session_state.transcription_status}") |
| 146 | + |
116 | 147 | if not is_configured():
|
117 | 148 | st.warning("Please fill in the required fields at the sidebar.")
|
118 | 149 |
|
119 |
| -st.info("Transcribe your speech.") |
120 |
| - |
121 | 150 | if st.button("Transcribe", disabled=not is_configured()):
|
| 151 | + st.session_state.transcription_status = TranscriptionStatus.InProgress |
122 | 152 | from_mic()
|
0 commit comments