Skip to content

Commit 99cef41

Browse files
authored
Add Kotlin and Java API for Kokoro TTS models (#1728)
1 parent 3a1de0b commit 99cef41

File tree

18 files changed

+549
-40
lines changed

18 files changed

+549
-40
lines changed

.github/workflows/run-java-test.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,12 @@ jobs:
234234
run: |
235235
cd ./java-api-examples
236236
237+
./run-non-streaming-tts-kokoro-en.sh
237238
./run-non-streaming-tts-matcha-zh.sh
238239
./run-non-streaming-tts-matcha-en.sh
240+
ls -lh
241+
242+
rm -rf kokoro-en-*
239243
240244
rm -rf matcha-icefall-*
241245
rm hifigan_v2.onnx

android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class MainActivity : AppCompatActivity() {
185185
var modelName: String?
186186
var acousticModelName: String?
187187
var vocoder: String?
188+
var voices: String?
188189
var ruleFsts: String?
189190
var ruleFars: String?
190191
var lexicon: String?
@@ -205,6 +206,10 @@ class MainActivity : AppCompatActivity() {
205206
vocoder = null
206207
// Matcha -- end
207208

209+
// For Kokoro -- begin
210+
voices = null
211+
// For Kokoro -- end
212+
208213

209214
modelDir = null
210215
ruleFsts = null
@@ -269,6 +274,13 @@ class MainActivity : AppCompatActivity() {
269274
// vocoder = "hifigan_v2.onnx"
270275
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
271276

277+
// Example 9
278+
// kokoro-en-v0_19
279+
// modelDir = "kokoro-en-v0_19"
280+
// modelName = "model.onnx"
281+
// voices = "voices.bin"
282+
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
283+
272284
if (dataDir != null) {
273285
val newDir = copyDataDir(dataDir!!)
274286
dataDir = "$newDir/$dataDir"
@@ -285,6 +297,7 @@ class MainActivity : AppCompatActivity() {
285297
modelName = modelName ?: "",
286298
acousticModelName = acousticModelName ?: "",
287299
vocoder = vocoder ?: "",
300+
voices = voices ?: "",
288301
lexicon = lexicon ?: "",
289302
dataDir = dataDir ?: "",
290303
dictDir = dictDir ?: "",

android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/GetSampleText.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ fun getSampleText(lang: String): String {
4747
}
4848

4949
"eng" -> {
50-
text = "This is a text-to-speech engine using next generation Kaldi"
50+
text = "How are you doing today? This is a text-to-speech engine using next generation Kaldi"
5151
}
5252

5353
"est" -> {

android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt

Lines changed: 185 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
package com.k2fsa.sherpa.onnx.tts.engine
44

55
import PreferenceHelper
6+
import android.media.AudioAttributes
7+
import android.media.AudioFormat
8+
import android.media.AudioManager
9+
import android.media.AudioTrack
610
import android.media.MediaPlayer
711
import android.net.Uri
812
import android.os.Bundle
@@ -36,7 +40,13 @@ import androidx.compose.ui.Modifier
3640
import androidx.compose.ui.text.input.KeyboardType
3741
import androidx.compose.ui.unit.dp
3842
import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme
43+
import kotlinx.coroutines.CoroutineScope
44+
import kotlinx.coroutines.Dispatchers
45+
import kotlinx.coroutines.channels.Channel
46+
import kotlinx.coroutines.launch
47+
import kotlinx.coroutines.withContext
3948
import java.io.File
49+
import kotlin.time.TimeSource
4050

4151
const val TAG = "sherpa-onnx-tts-engine"
4252

@@ -45,9 +55,26 @@ class MainActivity : ComponentActivity() {
4555
private val ttsViewModel: TtsViewModel by viewModels()
4656

4757
private var mediaPlayer: MediaPlayer? = null
58+
59+
// see
60+
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
61+
private lateinit var track: AudioTrack
62+
63+
private var stopped: Boolean = false
64+
65+
private var samplesChannel = Channel<FloatArray>()
66+
4867
override fun onCreate(savedInstanceState: Bundle?) {
4968
super.onCreate(savedInstanceState)
69+
70+
Log.i(TAG, "Start to initialize TTS")
5071
TtsEngine.createTts(this)
72+
Log.i(TAG, "Finish initializing TTS")
73+
74+
Log.i(TAG, "Start to initialize AudioTrack")
75+
initAudioTrack()
76+
Log.i(TAG, "Finish initializing AudioTrack")
77+
5178
val preferenceHelper = PreferenceHelper(this)
5279
setContent {
5380
SherpaOnnxTtsEngineTheme {
@@ -77,6 +104,11 @@ class MainActivity : ComponentActivity() {
77104
val testTextContent = getSampleText(TtsEngine.lang ?: "")
78105

79106
var testText by remember { mutableStateOf(testTextContent) }
107+
var startEnabled by remember { mutableStateOf(true) }
108+
var playEnabled by remember { mutableStateOf(false) }
109+
var rtfText by remember {
110+
mutableStateOf("")
111+
}
80112

81113
val numSpeakers = TtsEngine.tts!!.numSpeakers()
82114
if (numSpeakers > 1) {
@@ -119,52 +151,117 @@ class MainActivity : ComponentActivity() {
119151

120152
Row {
121153
Button(
122-
modifier = Modifier.padding(20.dp),
154+
enabled = startEnabled,
155+
modifier = Modifier.padding(5.dp),
123156
onClick = {
124157
Log.i(TAG, "Clicked, text: $testText")
125158
if (testText.isBlank() || testText.isEmpty()) {
126159
Toast.makeText(
127160
applicationContext,
128-
"Please input a test sentence",
161+
"Please input some text to generate",
129162
Toast.LENGTH_SHORT
130163
).show()
131164
} else {
132-
val audio = TtsEngine.tts!!.generate(
133-
text = testText,
134-
sid = TtsEngine.speakerId,
135-
speed = TtsEngine.speed,
136-
)
137-
138-
val filename =
139-
application.filesDir.absolutePath + "/generated.wav"
140-
val ok =
141-
audio.samples.isNotEmpty() && audio.save(
142-
filename
143-
)
165+
startEnabled = false
166+
playEnabled = false
167+
stopped = false
144168

145-
if (ok) {
146-
stopMediaPlayer()
147-
mediaPlayer = MediaPlayer.create(
148-
applicationContext,
149-
Uri.fromFile(File(filename))
150-
)
151-
mediaPlayer?.start()
152-
} else {
153-
Log.i(TAG, "Failed to generate or save audio")
169+
track.pause()
170+
track.flush()
171+
track.play()
172+
rtfText = ""
173+
Log.i(TAG, "Started with text $testText")
174+
175+
samplesChannel = Channel<FloatArray>()
176+
177+
CoroutineScope(Dispatchers.IO).launch {
178+
for (samples in samplesChannel) {
179+
track.write(
180+
samples,
181+
0,
182+
samples.size,
183+
AudioTrack.WRITE_BLOCKING
184+
)
185+
if (stopped) {
186+
break
187+
}
188+
}
154189
}
190+
191+
CoroutineScope(Dispatchers.Default).launch {
192+
val timeSource = TimeSource.Monotonic
193+
val startTime = timeSource.markNow()
194+
195+
val audio =
196+
TtsEngine.tts!!.generateWithCallback(
197+
text = testText,
198+
sid = TtsEngine.speakerId,
199+
speed = TtsEngine.speed,
200+
callback = ::callback,
201+
)
202+
203+
val elapsed =
204+
startTime.elapsedNow().inWholeMilliseconds.toFloat() / 1000;
205+
val audioDuration =
206+
audio.samples.size / TtsEngine.tts!!.sampleRate()
207+
.toFloat()
208+
val RTF = String.format(
209+
"Number of threads: %d\nElapsed: %.3f s\nAudio duration: %.3f s\nRTF: %.3f/%.3f = %.3f",
210+
TtsEngine.tts!!.config.model.numThreads,
211+
audioDuration,
212+
elapsed,
213+
elapsed,
214+
audioDuration,
215+
elapsed / audioDuration
216+
)
217+
samplesChannel.close()
218+
219+
val filename =
220+
application.filesDir.absolutePath + "/generated.wav"
221+
222+
223+
val ok =
224+
audio.samples.isNotEmpty() && audio.save(
225+
filename
226+
)
227+
228+
if (ok) {
229+
withContext(Dispatchers.Main) {
230+
startEnabled = true
231+
playEnabled = true
232+
rtfText = RTF
233+
}
234+
}
235+
}.start()
155236
}
156237
}) {
157-
Text("Test")
238+
Text("Start")
158239
}
159240

160241
Button(
161-
modifier = Modifier.padding(20.dp),
242+
modifier = Modifier.padding(5.dp),
243+
enabled = playEnabled,
162244
onClick = {
163-
TtsEngine.speakerId = 0
164-
TtsEngine.speed = 1.0f
165-
testText = ""
245+
stopped = true
246+
track.pause()
247+
track.flush()
248+
onClickPlay()
166249
}) {
167-
Text("Reset")
250+
Text("Play")
251+
}
252+
253+
Button(
254+
modifier = Modifier.padding(5.dp),
255+
onClick = {
256+
onClickStop()
257+
startEnabled = true
258+
}) {
259+
Text("Stop")
260+
}
261+
}
262+
if (rtfText.isNotEmpty()) {
263+
Row {
264+
Text(rtfText)
168265
}
169266
}
170267
}
@@ -185,4 +282,63 @@ class MainActivity : ComponentActivity() {
185282
mediaPlayer?.release()
186283
mediaPlayer = null
187284
}
285+
286+
/**
 * Plays back generated.wav from the application's files directory,
 * stopping any playback that is already in progress first.
 */
private fun onClickPlay() {
    val wavFile = File(application.filesDir.absolutePath + "/generated.wav")

    stopMediaPlayer()
    mediaPlayer = MediaPlayer.create(applicationContext, Uri.fromFile(wavFile))
    mediaPlayer?.start()
}
295+
296+
/**
 * Stops everything that may currently be producing sound: raises the
 * [stopped] flag, pauses and flushes the streaming [track], and stops the
 * media player.
 */
private fun onClickStop() {
    stopped = true

    track.run {
        pause()
        flush()
    }

    stopMediaPlayer()
}
303+
304+
// this function is called from C++
305+
/**
 * Receives one chunk of generated samples (invoked from C++ during
 * generation).
 *
 * @param samples the chunk of audio samples; it is copied before being
 *   forwarded to [samplesChannel].
 * @return 1 when the chunk was accepted, 0 once [stopped] has been set.
 */
private fun callback(samples: FloatArray): Int {
    if (stopped) {
        track.stop()
        Log.i(TAG, " return 0")
        return 0
    }

    // Copy before handing off, since the send happens asynchronously.
    // NOTE(review): each chunk is sent from its own launched coroutine;
    // confirm that chunk ordering is preserved under load.
    val samplesCopy = samples.copyOf()
    CoroutineScope(Dispatchers.IO).launch {
        samplesChannel.send(samplesCopy)
    }
    return 1
}
318+
319+
/**
 * Creates the streaming [track] used for playback during generation and
 * starts it.
 *
 * The track is mono, float PCM, at the TTS engine's sample rate, and uses
 * the minimum buffer size reported by [AudioTrack.getMinBufferSize].
 */
private fun initAudioTrack() {
    val sampleRate = TtsEngine.tts!!.sampleRate()
    val bufLength = AudioTrack.getMinBufferSize(
        sampleRate,
        AudioFormat.CHANNEL_OUT_MONO,
        AudioFormat.ENCODING_PCM_FLOAT,
    )
    Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength")

    val audioAttributes = AudioAttributes.Builder().run {
        setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
        setUsage(AudioAttributes.USAGE_MEDIA)
        build()
    }

    val audioFormat = AudioFormat.Builder().run {
        setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
        setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
        setSampleRate(sampleRate)
        build()
    }

    track = AudioTrack(
        audioAttributes,
        audioFormat,
        bufLength,
        AudioTrack.MODE_STREAM,
        AudioManager.AUDIO_SESSION_ID_GENERATE,
    )
    track.play()
}
188344
}

android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ object TtsEngine {
4141

4242
private var modelDir: String? = null
4343
private var modelName: String? = null
44-
private var acousticModelName: String? = null
45-
private var vocoder: String? = null
44+
private var acousticModelName: String? = null // for matcha tts
45+
private var vocoder: String? = null // for matcha tts
46+
private var voices: String? = null // for kokoro
4647
private var ruleFsts: String? = null
4748
private var ruleFars: String? = null
4849
private var lexicon: String? = null
@@ -64,6 +65,10 @@ object TtsEngine {
6465
vocoder = null
6566
// For Matcha -- end
6667

68+
// For Kokoro -- begin
69+
voices = null
70+
// For Kokoro -- end
71+
6772
modelDir = null
6873
ruleFsts = null
6974
ruleFars = null
@@ -139,6 +144,14 @@ object TtsEngine {
139144
// vocoder = "hifigan_v2.onnx"
140145
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
141146
// lang = "eng"
147+
148+
// Example 9
149+
// kokoro-en-v0_19
150+
// modelDir = "kokoro-en-v0_19"
151+
// modelName = "model.onnx"
152+
// voices = "voices.bin"
153+
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
154+
// lang = "eng"
142155
}
143156

144157
fun createTts(context: Context) {
@@ -167,6 +180,7 @@ object TtsEngine {
167180
modelName = modelName ?: "",
168181
acousticModelName = acousticModelName ?: "",
169182
vocoder = vocoder ?: "",
183+
voices = voices ?: "",
170184
lexicon = lexicon ?: "",
171185
dataDir = dataDir ?: "",
172186
dictDir = dictDir ?: "",
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
<resources>
2-
<string name="app_name">TTS Engine</string>
2+
<string name="app_name">TTS Engine: Next-gen Kaldi</string>
33
</resources>

0 commit comments

Comments
 (0)