@@ -371,9 +371,8 @@ class TfLiteAudioFeatureProvider {
371
371
int kFeatureSliceStrideMs = 20 ;
372
372
int kFeatureSliceDurationMs = 30 ;
373
373
374
- // Variables for the model's output categories.
375
- int kSilenceIndex = 0 ;
376
- int kUnknownIndex = 1 ;
374
+ // number of new slices to collect before evaluating the model
375
+ int kSlicesToProcess = 2 ;
377
376
378
377
// Callback method for result
379
378
void (*respondToCommand)(const char * found_command, uint8_t score,
@@ -535,11 +534,22 @@ class TfLiteAudioOutput : public AudioPrint {
535
534
public:
536
535
/// Default constructor: takes no arguments and has an empty body.
TfLiteAudioOutput () {}
537
536
538
- // The name of this function is important for Arduino compatibility.
537
+ // / Optionally define your own recognizer
538
+ void setRecognizer (RecognizeCommands<N>* p_recognizer) {
539
+ recognizer = p_recognizer;
540
+ }
541
+
542
+ // / Optionally define your own interpreter
543
+ void setInterpreter (tflite::MicroInterpreter* p_interpreter) {
544
+ this ->interpreter = p_interpreter;
545
+ }
546
+
547
+ // / Start the processing
539
548
virtual bool begin (const unsigned char * model,
540
549
TfLiteAudioFeatureProvider& featureProvider,
541
- const char ** labels, int tensorArenaSize = 10 * 1024 ) {
550
+ const char ** labels, int tensorArenaSize = 10 * 1024 , bool all_ops_resolver= false ) {
542
551
LOGD (LOG_METHOD);
552
+ this ->use_all_ops_resolver = all_ops_resolver;
543
553
this ->kTensorArenaSize = tensorArenaSize;
544
554
545
555
// setup the feature provider
@@ -552,12 +562,15 @@ class TfLiteAudioOutput : public AudioPrint {
552
562
553
563
// Map the model into a usable data structure. This doesn't involve any
554
564
// copying or parsing, it's a very lightweight operation.
555
- if (!setupModel (model)) {
565
+ if (!setModel (model)) {
556
566
return false ;
557
567
}
558
568
559
- if (!setupInterpreter ()) {
560
- return false ;
569
+ // setup default interpreter if not assigned yet
570
+ if (interpreter == nullptr ) {
571
+ if (!setupInterpreter ()) {
572
+ return false ;
573
+ }
561
574
}
562
575
563
576
// Allocate memory from the tensor_arena for the model's tensors.
@@ -586,9 +599,12 @@ class TfLiteAudioOutput : public AudioPrint {
586
599
return false ;
587
600
}
588
601
589
- static RecognizeCommands<N> static_recognizer;
590
- recognizer = &static_recognizer;
591
- recognizer->setLabels (labels);
602
+ // setup default recognizer if not defined
603
+ if (recognizer == nullptr ) {
604
+ static RecognizeCommands<N> static_recognizer;
605
+ recognizer = &static_recognizer;
606
+ recognizer->setLabels (labels);
607
+ }
592
608
593
609
// all good if we made it here
594
610
is_setup = true ;
@@ -597,7 +613,7 @@ class TfLiteAudioOutput : public AudioPrint {
597
613
}
598
614
599
615
// / How many bytes can we write next
600
- int availableForWrite () { return feature_provider-> availableForWrite () ; }
616
+ int availableForWrite () { return DEFAULT_BUFFER_SIZE ; }
601
617
602
618
// / process the data in batches of max kMaxAudioSampleSize.
603
619
size_t write (const uint8_t * audio, size_t bytes) {
@@ -609,8 +625,7 @@ class TfLiteAudioOutput : public AudioPrint {
609
625
// we submit int16 data which will be reduced to 8bits so we can send
610
626
// double the amount - 2 channels will be reduced to 1 so we multiply by
611
627
// number of channels
612
- int maxBytes = feature_provider->kMaxAudioSampleSize * 2 *
613
- feature_provider->kAudioChannels ;
628
+ int maxBytes = feature_provider->kMaxAudioSampleSize * 2 * feature_provider->kAudioChannels ;
614
629
while (open > 0 ) {
615
630
int len = min (open, maxBytes);
616
631
result += process (audio + pos, len);
@@ -626,8 +641,11 @@ class TfLiteAudioOutput : public AudioPrint {
626
641
TfLiteTensor* model_input = nullptr ;
627
642
TfLiteAudioFeatureProvider* feature_provider = nullptr ;
628
643
RecognizeCommands<N>* recognizer = nullptr ;
629
- int32_t previous_time = 0 ;
644
+ int32_t current_time = 0 ;
645
+ int16_t total_slice_count = 0 ;
630
646
bool is_setup = false ;
647
+ bool use_all_ops_resolver = false ;
648
+
631
649
632
650
// Create an area of memory to use for input, output, and intermediate
633
651
// arrays. The size of this will depend on the model you're using, and may
@@ -637,7 +655,7 @@ class TfLiteAudioOutput : public AudioPrint {
637
655
int8_t * feature_buffer = nullptr ;
638
656
int8_t * model_input_buffer = nullptr ;
639
657
640
- bool setupModel (const unsigned char * model) {
658
+ bool setModel (const unsigned char * model) {
641
659
LOGD (LOG_METHOD);
642
660
p_model = tflite::GetModel (model);
643
661
if (p_model->version () != TFLITE_SCHEMA_VERSION) {
@@ -658,43 +676,49 @@ class TfLiteAudioOutput : public AudioPrint {
658
676
//
659
677
bool setupInterpreter () {
660
678
LOGD (LOG_METHOD);
661
- // tflite::AllOpsResolver resolver;
662
-
663
- // NOLINTNEXTLINE(runtime-global-variables)
664
- static tflite::MicroMutableOpResolver<4 > micro_op_resolver (error_reporter);
665
- if (micro_op_resolver.AddDepthwiseConv2D () != kTfLiteOk ) {
666
- return false ;
667
- }
668
- if (micro_op_resolver.AddFullyConnected () != kTfLiteOk ) {
669
- return false ;
670
- }
671
- if (micro_op_resolver.AddSoftmax () != kTfLiteOk ) {
672
- return false ;
673
- }
674
- if (micro_op_resolver.AddReshape () != kTfLiteOk ) {
675
- return false ;
679
+ if (use_all_ops_resolver){
680
+ tflite::AllOpsResolver resolver;
681
+ static tflite::MicroInterpreter static_interpreter (
682
+ p_model, resolver, tensor_arena, kTensorArenaSize ,
683
+ error_reporter);
684
+ interpreter = &static_interpreter;
685
+ } else {
686
+ // NOLINTNEXTLINE(runtime-global-variables)
687
+ static tflite::MicroMutableOpResolver<4 > micro_op_resolver (error_reporter);
688
+ if (micro_op_resolver.AddDepthwiseConv2D () != kTfLiteOk ) {
689
+ return false ;
690
+ }
691
+ if (micro_op_resolver.AddFullyConnected () != kTfLiteOk ) {
692
+ return false ;
693
+ }
694
+ if (micro_op_resolver.AddSoftmax () != kTfLiteOk ) {
695
+ return false ;
696
+ }
697
+ if (micro_op_resolver.AddReshape () != kTfLiteOk ) {
698
+ return false ;
699
+ }
700
+ // Build an interpreter to run the model with.
701
+ static tflite::MicroInterpreter static_interpreter (
702
+ p_model, micro_op_resolver, tensor_arena, kTensorArenaSize ,
703
+ error_reporter);
704
+ interpreter = &static_interpreter;
676
705
}
677
- // Build an interpreter to run the model with.
678
- static tflite::MicroInterpreter static_interpreter (
679
- p_model, micro_op_resolver, tensor_arena, kTensorArenaSize ,
680
- error_reporter);
681
- interpreter = &static_interpreter;
682
706
return true ;
683
707
}
684
708
685
709
// The name of this function is important for Arduino compatibility. Returns
686
710
// the number of bytes
687
711
size_t process (const uint8_t * audio, size_t bytes) {
688
712
LOGD (" process: %u" , (unsigned )bytes);
689
- // Fetch the spectrogram for the current time.
690
- int how_many_new_slices = feature_provider->write (audio, bytes);
713
+ // Update the spectrogram
714
+ total_slice_count + = feature_provider->write (audio, bytes);
691
715
692
- // If no new audio samples have been received since last time, don't
693
- // bother running the network model.
694
- if (how_many_new_slices == 0 ) {
716
+ // run network model only if we have the necessary slices
717
+ if (total_slice_count <= feature_provider->kSlicesToProcess ) {
695
718
return bytes;
696
719
}
697
- LOGI (" ->slices: %d" , how_many_new_slices);
720
+
721
+ LOGI (" ->slices: %d" , total_slice_count);
698
722
// Copy feature buffer to input tensor
699
723
for (int i = 0 ; i < feature_provider->featureElementCount (); i++) {
700
724
model_input_buffer[i] = feature_buffer[i];
@@ -709,12 +733,15 @@ class TfLiteAudioOutput : public AudioPrint {
709
733
710
734
// Obtain a pointer to the output tensor
711
735
TfLiteTensor* output = interpreter->output (0 );
736
+
737
+ // determine time
738
+ current_time += feature_provider->kFeatureSliceStrideMs * total_slice_count;
712
739
// Determine whether a command was recognized based on the output of
713
740
// inference
714
741
const char * found_command = nullptr ;
715
742
uint8_t score = 0 ;
716
743
bool is_new_command = false ;
717
- unsigned long current_time = millis ();
744
+
718
745
TfLiteStatus process_status = recognizer->ProcessLatestResults (
719
746
output, current_time, &found_command, &score, &is_new_command);
720
747
if (process_status != kTfLiteOk ) {
@@ -725,7 +752,7 @@ class TfLiteAudioOutput : public AudioPrint {
725
752
// implementation just prints to the error console, but you should replace
726
753
// this with your own function for a real application.
727
754
respondToCommand (found_command, score, is_new_command);
728
-
755
+ total_slice_count = 0 ;
729
756
// all processed
730
757
return bytes;
731
758
}
0 commit comments