Skip to content

Commit ac9133b

Browse files
committed
Tensorflow DRAFT
1 parent 1cb09ae commit ac9133b

File tree

2 files changed

+76
-45
lines changed

2 files changed

+76
-45
lines changed

examples/sandbox/streams-audiokit-tf/streams-audiokit-tf.ino

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ const char* kCategoryLabels[4] = {
1414
"no",
1515
};
1616
StreamCopy copier(tfl, kit); // copy mic to tfl
17-
int channels = 2;
17+
int channels = 1;
1818
int samples_per_second = 16000;
1919

2020
void respondToCommand(const char* found_command, uint8_t score,
@@ -36,6 +36,10 @@ void setup() {
3636
cfg.input_device = AUDIO_HAL_ADC_INPUT_LINE2;
3737
cfg.channels = channels;
3838
cfg.sample_rate = samples_per_second;
39+
cfg.use_apll = false;
40+
cfg.auto_clear = false;
41+
cfg.buffer_size = 512;
42+
cfg.buffer_count = 16;
3943
kit.begin(cfg);
4044

4145
// Setup tensorflow

src/AudioLibs/TfLiteAudioOutput.h

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,8 @@ class TfLiteAudioFeatureProvider {
371371
int kFeatureSliceStrideMs = 20;
372372
int kFeatureSliceDurationMs = 30;
373373

374-
// Variables for the model's output categories.
375-
int kSilenceIndex = 0;
376-
int kUnknownIndex = 1;
374+
// number of new slices to collect before evaluating the model
375+
int kSlicesToProcess = 2;
377376

378377
// Callback method for result
379378
void (*respondToCommand)(const char* found_command, uint8_t score,
@@ -535,11 +534,22 @@ class TfLiteAudioOutput : public AudioPrint {
535534
public:
536535
TfLiteAudioOutput() {}
537536

538-
// The name of this function is important for Arduino compatibility.
537+
/// Optionally define your own recognizer
538+
void setRecognizer(RecognizeCommands<N>* p_recognizer) {
539+
recognizer = p_recognizer;
540+
}
541+
542+
/// Optionally define your own interpreter
543+
void setInterpreter(tflite::MicroInterpreter* p_interpreter) {
544+
this->interpreter = p_interpreter;
545+
}
546+
547+
/// Start the processing
539548
virtual bool begin(const unsigned char* model,
540549
TfLiteAudioFeatureProvider& featureProvider,
541-
const char** labels, int tensorArenaSize = 10 * 1024) {
550+
const char** labels, int tensorArenaSize = 10 * 1024, bool all_ops_resolver=false) {
542551
LOGD(LOG_METHOD);
552+
this->use_all_ops_resolver = all_ops_resolver;
543553
this->kTensorArenaSize = tensorArenaSize;
544554

545555
// setup the feature provider
@@ -552,12 +562,15 @@ class TfLiteAudioOutput : public AudioPrint {
552562

553563
// Map the model into a usable data structure. This doesn't involve any
554564
// copying or parsing, it's a very lightweight operation.
555-
if (!setupModel(model)) {
565+
if (!setModel(model)) {
556566
return false;
557567
}
558568

559-
if (!setupInterpreter()) {
560-
return false;
569+
// setup default interpreter if not assigned yet
570+
if (interpreter == nullptr) {
571+
if (!setupInterpreter()) {
572+
return false;
573+
}
561574
}
562575

563576
// Allocate memory from the tensor_arena for the model's tensors.
@@ -586,9 +599,12 @@ class TfLiteAudioOutput : public AudioPrint {
586599
return false;
587600
}
588601

589-
static RecognizeCommands<N> static_recognizer;
590-
recognizer = &static_recognizer;
591-
recognizer->setLabels(labels);
602+
// setup default recognizer if not defined
603+
if (recognizer == nullptr) {
604+
static RecognizeCommands<N> static_recognizer;
605+
recognizer = &static_recognizer;
606+
recognizer->setLabels(labels);
607+
}
592608

593609
// all good if we made it here
594610
is_setup = true;
@@ -597,7 +613,7 @@ class TfLiteAudioOutput : public AudioPrint {
597613
}
598614

599615
/// How many bytes can we write next
600-
int availableForWrite() { return feature_provider->availableForWrite(); }
616+
int availableForWrite() { return DEFAULT_BUFFER_SIZE; }
601617

602618
/// process the data in batches of max kMaxAudioSampleSize.
603619
size_t write(const uint8_t* audio, size_t bytes) {
@@ -609,8 +625,7 @@ class TfLiteAudioOutput : public AudioPrint {
609625
// we submit int16 data which will be reduced to 8bits so we can send
610626
// double the amount - 2 channels will be reduced to 1 so we multiply by
611627
// number of channels
612-
int maxBytes = feature_provider->kMaxAudioSampleSize * 2 *
613-
feature_provider->kAudioChannels;
628+
int maxBytes = feature_provider->kMaxAudioSampleSize * 2 * feature_provider->kAudioChannels;
614629
while (open > 0) {
615630
int len = min(open, maxBytes);
616631
result += process(audio + pos, len);
@@ -626,8 +641,11 @@ class TfLiteAudioOutput : public AudioPrint {
626641
TfLiteTensor* model_input = nullptr;
627642
TfLiteAudioFeatureProvider* feature_provider = nullptr;
628643
RecognizeCommands<N>* recognizer = nullptr;
629-
int32_t previous_time = 0;
644+
int32_t current_time = 0;
645+
int16_t total_slice_count = 0;
630646
bool is_setup = false;
647+
bool use_all_ops_resolver = false;
648+
631649

632650
// Create an area of memory to use for input, output, and intermediate
633651
// arrays. The size of this will depend on the model you're using, and may
@@ -637,7 +655,7 @@ class TfLiteAudioOutput : public AudioPrint {
637655
int8_t* feature_buffer = nullptr;
638656
int8_t* model_input_buffer = nullptr;
639657

640-
bool setupModel(const unsigned char* model) {
658+
bool setModel(const unsigned char* model) {
641659
LOGD(LOG_METHOD);
642660
p_model = tflite::GetModel(model);
643661
if (p_model->version() != TFLITE_SCHEMA_VERSION) {
@@ -658,43 +676,49 @@ class TfLiteAudioOutput : public AudioPrint {
658676
//
659677
bool setupInterpreter() {
660678
LOGD(LOG_METHOD);
661-
// tflite::AllOpsResolver resolver;
662-
663-
// NOLINTNEXTLINE(runtime-global-variables)
664-
static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter);
665-
if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
666-
return false;
667-
}
668-
if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) {
669-
return false;
670-
}
671-
if (micro_op_resolver.AddSoftmax() != kTfLiteOk) {
672-
return false;
673-
}
674-
if (micro_op_resolver.AddReshape() != kTfLiteOk) {
675-
return false;
679+
if (use_all_ops_resolver){
680+
tflite::AllOpsResolver resolver;
681+
static tflite::MicroInterpreter static_interpreter(
682+
p_model, resolver, tensor_arena, kTensorArenaSize,
683+
error_reporter);
684+
interpreter = &static_interpreter;
685+
} else {
686+
// NOLINTNEXTLINE(runtime-global-variables)
687+
static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter);
688+
if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
689+
return false;
690+
}
691+
if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) {
692+
return false;
693+
}
694+
if (micro_op_resolver.AddSoftmax() != kTfLiteOk) {
695+
return false;
696+
}
697+
if (micro_op_resolver.AddReshape() != kTfLiteOk) {
698+
return false;
699+
}
700+
// Build an interpreter to run the model with.
701+
static tflite::MicroInterpreter static_interpreter(
702+
p_model, micro_op_resolver, tensor_arena, kTensorArenaSize,
703+
error_reporter);
704+
interpreter = &static_interpreter;
676705
}
677-
// Build an interpreter to run the model with.
678-
static tflite::MicroInterpreter static_interpreter(
679-
p_model, micro_op_resolver, tensor_arena, kTensorArenaSize,
680-
error_reporter);
681-
interpreter = &static_interpreter;
682706
return true;
683707
}
684708

685709
// The name of this function is important for Arduino compatibility. Returns
686710
// the number of bytes
687711
size_t process(const uint8_t* audio, size_t bytes) {
688712
LOGD("process: %u", (unsigned)bytes);
689-
// Fetch the spectrogram for the current time.
690-
int how_many_new_slices = feature_provider->write(audio, bytes);
713+
// Update the spectrogram
714+
total_slice_count += feature_provider->write(audio, bytes);
691715

692-
// If no new audio samples have been received since last time, don't
693-
// bother running the network model.
694-
if (how_many_new_slices == 0) {
716+
// run network model only if we have the necessary slices
717+
if (total_slice_count <= feature_provider->kSlicesToProcess) {
695718
return bytes;
696719
}
697-
LOGI("->slices: %d", how_many_new_slices);
720+
721+
LOGI("->slices: %d", total_slice_count);
698722
// Copy feature buffer to input tensor
699723
for (int i = 0; i < feature_provider->featureElementCount(); i++) {
700724
model_input_buffer[i] = feature_buffer[i];
@@ -709,12 +733,15 @@ class TfLiteAudioOutput : public AudioPrint {
709733

710734
// Obtain a pointer to the output tensor
711735
TfLiteTensor* output = interpreter->output(0);
736+
737+
// determine time
738+
current_time += feature_provider->kFeatureSliceStrideMs * total_slice_count;
712739
// Determine whether a command was recognized based on the output of
713740
// inference
714741
const char* found_command = nullptr;
715742
uint8_t score = 0;
716743
bool is_new_command = false;
717-
unsigned long current_time = millis();
744+
718745
TfLiteStatus process_status = recognizer->ProcessLatestResults(
719746
output, current_time, &found_command, &score, &is_new_command);
720747
if (process_status != kTfLiteOk) {
@@ -725,7 +752,7 @@ class TfLiteAudioOutput : public AudioPrint {
725752
// implementation just prints to the error console, but you should replace
726753
// this with your own function for a real application.
727754
respondToCommand(found_command, score, is_new_command);
728-
755+
total_slice_count = 0;
729756
// all processed
730757
return bytes;
731758
}

0 commit comments

Comments
 (0)