Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';

import 'model_type_test.dart';
import 'vad_iterator.dart';

Expand All @@ -22,6 +23,7 @@ class _MyAppState extends State<MyApp> {
VadIterator? _vadIterator;
static const frameSize = 64;
static const sampleRate = 16000;
String? costTime;

@override
void initState() {
Expand All @@ -34,6 +36,7 @@ class _MyAppState extends State<MyApp> {
@override
Widget build(BuildContext context) {
const textStyle = TextStyle(fontSize: 16);

return MaterialApp(
theme: ThemeData(useMaterial3: true),
home: Scaffold(
Expand All @@ -56,26 +59,31 @@ class _MyAppState extends State<MyApp> {
height: 50,
),
TextButton(
onPressed: () {
_typeTest();
},
child: const Text('Mode Type Test')),
onPressed: _typeTest,
child: const Text('Mode Type Test'),
),
const SizedBox(
height: 50,
),
TextButton(
onPressed: () {
_vad(false);
},
child: const Text('VAD')),
onPressed: () => _vad(false),
child: const Text('VAD'),
),
const SizedBox(
height: 50,
),
TextButton(
onPressed: () {
_vad(true);
},
child: const Text('VAD Concurrency')),
onPressed: () => _vad(true),
child: const Text('VAD Concurrency'),
),
const SizedBox(
height: 50,
),
Text(
costTime ?? '',
style: textStyle,
textAlign: TextAlign.center,
),
],
),
),
Expand All @@ -85,33 +93,43 @@ class _MyAppState extends State<MyApp> {
}

_typeTest() async {
setState(() => costTime = null);
final startTime = DateTime.now().millisecondsSinceEpoch;
List<OrtValue?>? outputs;
outputs = await ModelTypeTest.testBool();

print('out=${outputs[0]?.value}');
outputs.forEach((element) {
element?.release();
});

outputs = await ModelTypeTest.testFloat();
print('out=${outputs[0]?.value}');
outputs.forEach((element) {
element?.release();
});

outputs = await ModelTypeTest.testInt64();
print('out=${outputs[0]?.value}');
outputs.forEach((element) {
element?.release();
});

outputs = await ModelTypeTest.testString();
print('out=${outputs[0]?.value}');
outputs.forEach((element) {
element?.release();
});

final endTime = DateTime.now().millisecondsSinceEpoch;
print('infer cost time=${endTime - startTime}ms');
String inferCost = 'Infer cost time=${endTime - startTime}ms';
print(inferCost);
setState(() => costTime = inferCost);
}

_vad(bool concurrent) async {
setState(() => costTime = null);

const windowByteCount = frameSize * 2 * sampleRate ~/ 1000;
final rawAssetFile = await rootBundle.load('assets/audio/vad_example.pcm');
final bytes = rawAssetFile.buffer.asUint8List();
Expand All @@ -130,7 +148,10 @@ class _MyAppState extends State<MyApp> {
}
_vadIterator?.reset();
final endTime = DateTime.now().millisecondsSinceEpoch;
print('vad cost time=${endTime - startTime}ms');

String costTimeTemp = 'Vad cost time=${endTime - startTime}ms';
print(costTimeTemp);
setState(() => costTime = costTimeTemp);
}

Int16List _transformBuffer(List<int> buffer) {
Expand Down
42 changes: 30 additions & 12 deletions example/lib/vad_iterator.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class VadIterator {
final _speechPadMs = 0;
late int _minSilenceSamples;
late int _speechPadSamples;

/// support 256 512 768 for 8k; 512 1024 1536 for 16k
late int _windowSizeSamples;

Expand All @@ -24,9 +25,12 @@ class VadIterator {
var _currentSample = 0;

static const int _batch = 1;

/// model inputs
var _hide = List.filled(2, List.filled(_batch, Float32List.fromList(List.filled(64, 0.0))));
var _cell = List.filled(2, List.filled(_batch, Float32List.fromList(List.filled(64, 0.0))));
var _hide = List.filled(
2, List.filled(_batch, Float32List.fromList(List.filled(64, 0.0))));
var _cell = List.filled(
2, List.filled(_batch, Float32List.fromList(List.filled(64, 0.0))));

VadIterator(this._frameSize, this._sampleRate) {
final srPerMs = _sampleRate ~/ 1000;
Expand All @@ -43,8 +47,10 @@ class VadIterator {
_triggered = false;
_tempEnd = 0;
_currentSample = 0;
_hide = List.filled(2, List.filled(_batch, Float32List.fromList(List.filled(64, 0.0))));
_cell = List.filled(2, List.filled(_batch, Float32List.fromList(List.filled(64, 0.0))));
_hide = List.filled(
2, List.filled(_batch, Float32List.fromList(List.filled(64, 0.0))));
_cell = List.filled(
2, List.filled(_batch, Float32List.fromList(List.filled(64, 0.0))));
}

release() {
Expand All @@ -61,14 +67,14 @@ class VadIterator {
..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
const assetFileName = 'assets/models/silero_vad.onnx';
final rawAssetFile = await rootBundle.load(assetFileName);
final bytes = rawAssetFile.buffer.asUint8List();
ByteData rawAssetFile = await rootBundle.load(assetFileName);
Uint8List bytes = rawAssetFile.buffer.asUint8List();
_session = OrtSession.fromBuffer(bytes, _sessionOptions!);
}

Future<bool> predict(Float32List data, bool concurrent) async {
final inputOrt =
OrtValueTensor.createTensorWithDataList(data, [_batch, _windowSizeSamples]);
final inputOrt = OrtValueTensor.createTensorWithDataList(
data, [_batch, _windowSizeSamples]);
final srOrt = OrtValueTensor.createTensorWithData(_sampleRate);
final hOrt = OrtValueTensor.createTensorWithDataList(_hide);
final cOrt = OrtValueTensor.createTensorWithDataList(_cell);
Expand All @@ -85,13 +91,19 @@ class VadIterator {
hOrt.release();
cOrt.release();
runOptions.release();

/// Output probability & update h,c recursively
final output = (outputs?[0]?.value as List<List<double>>)[0][0];
_hide = (outputs?[1]?.value as List<List<List<double>>>).map((e) => e.map((e) => Float32List.fromList(e)).toList()).toList();
_cell = (outputs?[2]?.value as List<List<List<double>>>).map((e) => e.map((e) => Float32List.fromList(e)).toList()).toList();
_hide = (outputs?[1]?.value as List<List<List<double>>>)
.map((e) => e.map((e) => Float32List.fromList(e)).toList())
.toList();
_cell = (outputs?[2]?.value as List<List<List<double>>>)
.map((e) => e.map((e) => Float32List.fromList(e)).toList())
.toList();
outputs?.forEach((element) {
element?.release();
});

/// Push forward sample index
_currentSample += _windowSizeSamples;

Expand All @@ -113,8 +125,10 @@ class VadIterator {
/// 3) Start
if (output >= _threshold && !_triggered) {
_triggered = true;

/// minus window_size_samples to get precise start time point.
final speechStart = _currentSample - _windowSizeSamples - _speechPadSamples;
final speechStart =
_currentSample - _windowSizeSamples - _speechPadSamples;
print('vad start: ${speechStart / _sampleRate}s');
}

Expand All @@ -123,13 +137,17 @@ class VadIterator {
if (_tempEnd == 0) {
_tempEnd = _currentSample;
}

/// a. silence < min_slience_samples, continue speaking
if (_currentSample - _tempEnd < _minSilenceSamples) {
print('vad speaking4: ${_currentSample / _sampleRate}s');
}

/// b. silence >= min_slience_samples, end speaking
else {
final speechEnd = _tempEnd > 0 ? _tempEnd + _speechPadSamples : _currentSample + _speechPadSamples;
final speechEnd = _tempEnd > 0
? _tempEnd + _speechPadSamples
: _currentSample + _speechPadSamples;
_tempEnd = 0;
_triggered = false;
print('vad end: ${speechEnd / _sampleRate}s');
Expand Down
Loading