Skip to content

Commit 57a6ef1

Browse files
committed
add ViTS Support
1 parent 6e02124 commit 57a6ef1

File tree

13 files changed

+614
-174
lines changed

13 files changed

+614
-174
lines changed

lib/global.dart

Lines changed: 135 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
import 'package:arabic_learning/statics_var.dart';
2+
import 'package:archive/archive.dart';
3+
import 'package:dio/dio.dart';
4+
import 'package:flutter/foundation.dart';
25
import 'package:flutter_tts/flutter_tts.dart';
36
import 'package:http/http.dart' as http;
7+
import 'package:just_audio/just_audio.dart' show AudioSource;
8+
import 'package:path_provider/path_provider.dart' as path_provider;
9+
import 'package:provider/provider.dart';
410
import 'package:shared_preferences/shared_preferences.dart';
511
import 'dart:convert';
612
import 'package:flutter/material.dart';
713
import 'package:google_fonts/google_fonts.dart';
8-
14+
import 'package_replacement/nonsense_hook_io.dart' if (dart.library.io) 'dart:io' as io;
15+
import 'package_replacement/nonsense_hook.dart' if (dart.library.io) 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
916

1017
class Global with ChangeNotifier {
1118
late bool firstStart;
1219
late bool isWideScreen;
1320
late final SharedPreferences prefs;
1421
late String dailyWord = "";
22+
late bool modelTTSDownloaded = false;
1523
Map<String, dynamic> _settingData = {
1624
'User': "",
1725
'regular': {
@@ -20,7 +28,7 @@ class Global with ChangeNotifier {
2028
"darkMode": false,
2129
},
2230
'audio': {
23-
"useBackupSource": false,
31+
"useBackupSource": 0, // 0: Normal, 1: OnlineBackup, 2: LocalVITS
2432
"playRate": 1.0,
2533
},
2634
'learning': {
@@ -64,10 +72,42 @@ class Global with ChangeNotifier {
6472
);
6573

6674
late Map<String, dynamic> wordData = {};
75+
late sherpa_onnx.OfflineTts vitsTTS;
6776
ThemeData get themeData => _themeData;
6877
Map<String, dynamic> get settingData => _settingData;
6978
int get wordCount => wordData["Words"]!.length;
7079

80+
/// Sets up the offline VITS engine when its model file is already on disk.
///
/// Returns immediately on web builds. Otherwise looks for
/// `ar_JO-kareem-medium.onnx` under [StaticsVar.modelPath] inside the
/// application documents directory. When the file exists,
/// [modelTTSDownloaded] is set to true and [vitsTTS] is assigned a freshly
/// constructed engine; when it does not, nothing is changed.
Future<void> loadTTS() async {
  if (kIsWeb) return;

  final docsDir = await path_provider.getApplicationDocumentsDirectory();
  final modelDir = "${docsDir.path}/${StaticsVar.modelPath}";
  final modelFile = "$modelDir/ar_JO-kareem-medium.onnx";

  // Nothing to load until the model has been downloaded.
  if (!io.File(modelFile).existsSync()) return;

  modelTTSDownloaded = true;
  sherpa_onnx.initBindings();

  vitsTTS = sherpa_onnx.OfflineTts(
    sherpa_onnx.OfflineTtsConfig(
      model: sherpa_onnx.OfflineTtsModelConfig(
        vits: sherpa_onnx.OfflineTtsVitsModelConfig(
          model: modelFile,
          dataDir: "$modelDir/espeak-ng-data",
          tokens: "$modelDir/tokens.txt",
          // The length scale is the reciprocal of the configured play rate.
          lengthScale: 1 / _settingData["audio"]["playRate"],
        ),
        numThreads: 2,
        debug: false,
        provider: 'cpu',
      ),
      maxNumSenetences: 1,
    ),
  );
}
110+
71111
Map<K, V> deepMerge<K, V>(Map<K, V> base, Map<K, V> overlay) {
72112
final result = Map<K, V>.from(base);
73113
overlay.forEach((key, value) {
@@ -83,6 +123,22 @@ class Global with ChangeNotifier {
83123
return result;
84124
}
85125

126+
/// Merges the persisted settings into [_settingData], migrating old formats.
///
/// Reads the JSON string stored under the "settingData" preference key,
/// upgrades any legacy values in it, and deep-merges the result over the
/// current [_settingData]. Does nothing when no settings have been stored.
void conveySetting() {
  final stored = prefs.getString("settingData");
  if (stored == null) return;

  final oldSetting = jsonDecode(stored) as Map<String, dynamic>;

  // For v0.1.5 upgrade: "useBackupSource" used to be a bool and is now an
  // int (false -> 0, true -> 1). The "audio" section is checked first so
  // stored settings without it do not crash the migration.
  final audio = oldSetting["audio"];
  if (audio is Map) {
    final legacy = audio["useBackupSource"];
    if (legacy is bool) {
      audio["useBackupSource"] = legacy ? 1 : 0;
    }
  }

  _settingData = deepMerge(_settingData, oldSetting);
}
141+
86142
Future<void> init() async {
87143
prefs = await SharedPreferences.getInstance();
88144
firstStart = prefs.getString("settingData") == null;
@@ -92,8 +148,9 @@ class Global with ChangeNotifier {
92148
} else {
93149
wordData = jsonDecode(prefs.getString("wordData")!) as Map<String, dynamic>;
94150
}
151+
await loadTTS();
95152
if (firstStart) return;
96-
_settingData = deepMerge(_settingData, jsonDecode(prefs.getString("settingData")!) as Map<String, dynamic>);
153+
conveySetting();
97154
}
98155
void updateTheme() {
99156
_themeData = ThemeData(
@@ -107,14 +164,14 @@ class Global with ChangeNotifier {
107164
notifyListeners();
108165
}
109166

110-
Future<void> acceptAggrement(String name) async {
167+
/// Records that the user accepted the agreement under the given [name].
///
/// Clears [firstStart], stores [name] as the "User" setting, writes the
/// encoded settings to the "settingData" preference key, and notifies
/// listeners. The preference write is started but not awaited.
void acceptAggrement(String name) {
  firstStart = false;
  _settingData["User"] = name;

  final encodedSettings = jsonEncode(_settingData);
  prefs.setString("settingData", encodedSettings);

  notifyListeners();
}
116173

117-
Future<void> updateSetting(Map<String, dynamic> settingData) async {
174+
void updateSetting(Map<String, dynamic> settingData) {
118175
_settingData = settingData;
119176
try {
120177
prefs.setString("settingData", jsonEncode(settingData));
@@ -184,7 +241,7 @@ class Global with ChangeNotifier {
184241
return exData;
185242
}
186243

187-
void importData(Map<String, dynamic> data, String source) async {
244+
void importData(Map<String, dynamic> data, String source) {
188245
wordData = dataFormater(data, wordData, source);
189246
prefs.setString("wordData", jsonEncode(wordData));
190247
notifyListeners();
@@ -246,9 +303,19 @@ class InDevelopingPage extends StatelessWidget {
246303
}
247304
}
248305

249-
Future<List<dynamic>> playTextToSpeech(String text, {bool useBackup = false, double playRate = 1.0}) async {
306+
Future<List<dynamic>> playTextToSpeech(String text, BuildContext context) async {
250307
// return [bool isSuccessed?, String errorInfo];
251-
if (useBackup) {
308+
// 0: System TTS
309+
if (context.read<Global>().settingData["audio"]["useBackupSource"] == 0) {
310+
FlutterTts flutterTts = FlutterTts();
311+
if(!(await flutterTts.getLanguages).toString().contains("ar")) return [false, "你的设备似乎未安装阿拉伯语语言或不支持阿拉伯语文本转语音功能,语音可能无法正常播放。\n你可以尝试在 设置 - 系统语言 - 添加语言 中添加阿拉伯语。\n实在无法使用可在设置页面启用备用音频源(需要网络)"];
312+
await flutterTts.setLanguage("ar");
313+
await flutterTts.setPitch(1.0);
314+
if(!context.mounted) return [false, "神经网络音频合成失败\n中途退出context"];
315+
await flutterTts.setSpeechRate(context.read<Global>().settingData["audio"]["playRate"] / 2);
316+
await flutterTts.speak(text);
317+
// 1: TextReadTTS
318+
} else if (context.read<Global>().settingData["audio"]["useBackupSource"] == 1) {
252319
try {
253320
final response = await http.get(Uri.parse("https://textreadtts.com/tts/convert?accessKey=FREE&language=arabic&speaker=speaker2&text=$text")).timeout(Duration(seconds: 8), onTimeout: () => throw Exception("请求超时"));
254321
if (response.statusCode == 200) {
@@ -257,21 +324,41 @@ Future<List<dynamic>> playTextToSpeech(String text, {bool useBackup = false, dou
257324
return [false, "备用音源请求失败:\n错误信息:文本长度超过API限制"];
258325
}
259326
await StaticsVar.player.setUrl(data["audio"]);
260-
await StaticsVar.player.setSpeed(playRate);
327+
if(!context.mounted) return [false, "神经网络音频合成失败\n中途退出context"];
328+
await StaticsVar.player.setSpeed(context.read<Global>().settingData["audio"]["playRate"]);
261329
await StaticsVar.player.play();
262330
} else {
263331
return [false, "备用音源请求失败:\n错误码:${response.statusCode.toString()}"];
264332
}
265333
} catch (e) {
266334
return [false, "备用音源请求失败:\n错误信息:${e.toString()}"];
267335
}
268-
} else {
269-
FlutterTts flutterTts = FlutterTts();
270-
if(!(await flutterTts.getLanguages).toString().contains("ar")) return [false, "你的设备似乎未安装阿拉伯语语言或不支持阿拉伯语文本转语音功能,语音可能无法正常播放。\n你可以尝试在 设置 - 系统语言 - 添加语言 中添加阿拉伯语。\n实在无法使用可在设置页面启用备用音频源(需要网络)"];
271-
await flutterTts.setLanguage("ar");
272-
await flutterTts.setPitch(1.0);
273-
await flutterTts.setSpeechRate(playRate / 2);
274-
await flutterTts.speak(text);
336+
337+
// 2: sherpa-onnx
338+
} else if (context.read<Global>().settingData["audio"]["useBackupSource"] == 2) {
339+
try {
340+
final basePath = await path_provider.getApplicationCacheDirectory();
341+
final cacheFile = io.File("${basePath.path}/temp.wav");
342+
if(cacheFile.existsSync()) cacheFile.deleteSync();
343+
if(!context.mounted) return [false, "神经网络音频合成失败\n中途退出context"];
344+
final audio = context.read<Global>().vitsTTS.generate(text: text, speed: context.read<Global>().settingData["audio"]["playRate"]);
345+
final ok = sherpa_onnx.writeWave(
346+
filename: cacheFile.path,
347+
samples: audio.samples,
348+
sampleRate: audio.sampleRate,
349+
);
350+
if(ok) {
351+
await StaticsVar.player.setAudioSource(AudioSource.uri(Uri.file(cacheFile.path)));
352+
// await StaticsVar.player.setSpeed(playRate);
353+
await StaticsVar.player.play();
354+
await Future.delayed(Duration(milliseconds: 1000));
355+
if(cacheFile.existsSync()) cacheFile.deleteSync();
356+
}else {
357+
return [false, "神经网络音频合成失败\n错误信息:无法将音频写入文件"];
358+
}
359+
} catch (e) {
360+
return [false, "神经网络音频合成失败\n错误信息:${e.toString()}"];
361+
}
275362
}
276363
return [true, ""];
277364
}
@@ -318,4 +405,36 @@ class TextContainer extends StatelessWidget {
318405
child: (selectable??false) ? SelectableText(text,style: (style == null) ? TextStyle(fontSize: 18.0) : style) : Text(text,style: (style == null) ? TextStyle(fontSize: 18.0) : style),
319406
);
320407
}
408+
}
409+
410+
/// Extracts the `.tar.bz2` archive at [inputPath] into [outputDir].
///
/// The archive is read and decompressed fully in memory, then every entry is
/// written below [outputDir]; missing parent directories are created.
///
/// Entries whose name is absolute (leading separator or drive prefix) or
/// contains a `..` segment are skipped, so an archive cannot place files
/// outside [outputDir].
Future<void> extractTarBz2(String inputPath, String outputDir) async {
  final bytes = await io.File(inputPath).readAsBytes();

  // Decompress the bz2 layer.
  final tarBytes = BZip2Decoder().decodeBytes(bytes);

  // Unpack the tar container.
  final tarArchive = TarDecoder().decodeBytes(tarBytes);

  // Write each entry to disk.
  final separators = RegExp(r'[/\\]');
  for (final file in tarArchive.files) {
    final segments = file.name.split(separators);
    final escapesOutputDir = segments.first.isEmpty ||
        segments.first.endsWith(':') ||
        segments.contains('..');
    if (escapesOutputDir) continue;

    // A single '/' is enough: it is accepted as a separator on every
    // platform, and appending the platform separator as well doubled it.
    final filePath = '$outputDir/${file.name}';
    if (file.isFile) {
      final outFile = io.File(filePath);
      await outFile.create(recursive: true);
      await outFile.writeAsBytes(file.content as List<int>);
    } else {
      await io.Directory(filePath).create(recursive: true);
    }
  }
}
432+
433+
/// Downloads [url] and saves the response to [savePath].
///
/// When [onDownloading] is given it is forwarded as the receive-progress
/// callback. Any error raised by the download propagates to the caller.
Future<void> downloadFile(String url, String savePath, {ProgressCallback? onDownloading}) async {
  final dio = Dio();
  try {
    await dio.download(
      url,
      savePath,
      onReceiveProgress: onDownloading,
    );
  } finally {
    // Release the underlying HTTP client once the transfer ends, whether it
    // succeeded or failed.
    dio.close();
  }
}

lib/home_page.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class HomePage extends StatelessWidget {
2222
if(context.read<Global>().dailyWord.isNotEmpty) {
2323
playing = true;
2424
late List<dynamic> temp;
25-
temp = await playTextToSpeech(context.read<Global>().dailyWord, useBackup: context.read<Global>().settingData['audio']["useBackupSource"], playRate: context.read<Global>().settingData['audio']["playRate"]);
25+
temp = await playTextToSpeech(context.read<Global>().dailyWord, context);
2626
if(!temp[0] && context.mounted) {
2727
alart(context, temp[1]);
2828
}

lib/learning_pages_build.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ List<Widget> questionConstructer(BuildContext context, int index, List<String> d
7676
playing = true;
7777
});
7878
late List<dynamic> temp;
79-
temp = await playTextToSpeech(data[0], useBackup: context.read<Global>().settingData['audio']["useBackupSource"], playRate: context.read<Global>().settingData['audio']["playRate"]);
79+
temp = await playTextToSpeech(data[0], context);
8080
if(!temp[0] && context.mounted) {
8181
alart(context, temp[1]);
8282
}

lib/main.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class _MyHomePageState extends State<MyHomePage> {
247247
),
248248
onPressed: () async {
249249
if(controller.text.isNotEmpty){
250-
await context.read<Global>().acceptAggrement(controller.text);
250+
context.read<Global>().acceptAggrement(controller.text);
251251
} else {
252252
ScaffoldMessenger.of(context).showSnackBar(
253253
const SnackBar(
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
2+
3+
/// Do-nothing stand-in exposing the `OfflineTts` members this project calls.
///
/// NOTE(review): presumably the replacement picked by a conditional import
/// when the real speech package is unavailable — confirm against the import.
class OfflineTts {
  /// Accepts a [config] and ignores it.
  OfflineTts(OfflineTtsConfig config);

  /// Accepts [text] and [speed], performs no synthesis, and returns null.
  dynamic generate({required String text, required speed}) => null;
}
8+
9+
/// No-op placeholder for native binding initialization.
///
/// [opt] defaults to 0 and is never read.
void initBindings({int? opt = 0}) {}
14+
15+
/// Plain value holder for the VITS model options; all fields are mutable.
class OfflineTtsVitsModelConfig {
  OfflineTtsVitsModelConfig({
    required this.model,
    required this.dataDir,
    required this.tokens,
    required this.lengthScale,
  });

  /// Value passed as `model`.
  String model;

  /// Value passed as `dataDir`.
  String dataDir;

  /// Value passed as `tokens`.
  String tokens;

  /// Value passed as `lengthScale`.
  double lengthScale;
}
26+
27+
/// Plain value holder grouping a VITS config with runtime options; all
/// fields are mutable.
class OfflineTtsModelConfig {
  OfflineTtsModelConfig({
    required this.vits,
    required this.numThreads,
    required this.debug,
    required this.provider,
  });

  /// Value passed as `vits`.
  OfflineTtsVitsModelConfig vits;

  /// Value passed as `numThreads`.
  int numThreads;

  /// Value passed as `debug`.
  bool debug;

  /// Value passed as `provider`.
  String provider;
}
38+
39+
/// Plain value holder for the top-level engine options; all fields are
/// mutable.
class OfflineTtsConfig {
  OfflineTtsConfig({
    required this.model,
    required this.maxNumSenetences,
  });

  /// Value passed as `model`.
  OfflineTtsModelConfig model;

  /// Value passed as `maxNumSenetences` (spelling is part of the interface).
  int maxNumSenetences;
}
45+
46+
/// Placeholder wave writer: ignores every argument, writes nothing, and
/// always returns false.
bool writeWave({
  required String filename,
  required dynamic samples,
  required int sampleRate,
}) =>
    false;
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
2+
/// Do-nothing stand-in exposing the `File` members this project calls.
///
/// NOTE(review): presumably the replacement picked by a conditional import
/// where `dart:io` is unavailable — confirm against the import.
class File {
  /// Accepts a [filePath] and ignores it.
  File(String filePath);

  /// Always the empty string.
  String get path => "";

  /// Always reports true.
  bool existsSync() => true;

  /// Does nothing.
  void deleteSync() {}

  /// Completes without creating anything.
  Future<void> create({required bool recursive}) async {}

  /// Completes without writing anything.
  Future<void> writeAsBytes(List<int> content) async {}

  /// Completes with null.
  Future readAsBytes() async {}

  /// Completes with the empty string.
  Future<String> readAsString() async => "";

  /// Does nothing.
  void delete() {}
}
22+
23+
/// Do-nothing stand-in exposing the `Directory` members this project calls.
class Directory {
  /// Accepts a [filePath] and ignores it.
  Directory(String filePath);

  /// Completes without creating anything.
  Future<void> create({required bool recursive}) async {
    return;
  }
}
28+
29+
/// Stand-in exposing the `Platform` statics this project reads; every
/// platform check answers false.
class Platform {
  /// Always the empty string.
  static const String pathSeparator = "";

  /// Always false.
  static bool get isWindows {
    return false;
  }

  /// Always false.
  static bool get isLinux {
    return false;
  }

  /// Always false.
  static bool get isMacOS {
    return false;
  }
}

0 commit comments

Comments
 (0)