1212
1313#include < gflags/gflags.h>
1414
15+ #include < executorch/extension/module/module.h>
16+ #include < executorch/extension/tensor/tensor_ptr_maker.h>
17+ #include < executorch/runtime/core/evalue.h>
18+
1519#include < executorch/extension/llm/runner/audio.h>
1620#include < executorch/extension/llm/runner/image.h>
1721#include < executorch/extension/llm/runner/llm_runner_helper.h>
@@ -36,6 +40,11 @@ DEFINE_string(prompt, "What is happening in this audio?", "Text prompt.");
3640
// Command-line flag holding the path of the audio input. main() reads it via
// FLAGS_audio_path and forwards it to processAudioFile().
3741DEFINE_string (audio_path, " " , " Path to input audio file." );
3842
// Command-line flag holding the path of an audio processor .pte. main() reads
// it via FLAGS_processor_path and forwards it to processAudioFile(); when it is
// non-empty, a .bin input is routed through processRawAudioFile() instead of
// being loaded as preprocessed features.
43+ DEFINE_string (
44+ processor_path,
45+ " " ,
46+ " Path to processor .pte file for raw audio processing." );
47+
3948DEFINE_double (
4049 temperature,
4150 0 .8f ,
@@ -50,10 +59,13 @@ DEFINE_bool(warmup, false, "Whether to run a warmup run.");
5059
5160namespace {
5261
62+ using ::executorch::extension::from_blob;
63+ using ::executorch::extension::Module;
5364using ::executorch::extension::llm::Image;
5465using ::executorch::extension::llm::make_image_input;
5566using ::executorch::extension::llm::make_text_input;
5667using ::executorch::extension::llm::MultimodalInput;
68+ using ::executorch::runtime::EValue;
5769
5870bool ends_with (const std::string& str, const std::string& suffix) {
5971 return str.size () >= suffix.size () &&
@@ -74,55 +86,185 @@ bool ends_with(const std::string& str, const std::string& suffix) {
7486 */
7587MultimodalInput loadPreprocessedAudio (const std::string& audio_path) {
7688 std::ifstream f (audio_path, std::ios::binary | std::ios::ate);
89+ if (!f.is_open ()) {
90+ ET_LOG (Error, " Failed to open audio file: %s" , audio_path.c_str ());
91+ throw std::runtime_error (" Failed to open audio file" );
92+ }
93+
94+ std::size_t n_floats = f.tellg () / sizeof (float );
95+ f.seekg (0 , std::ios::beg);
96+
7797 int32_t n_bins = 128 ;
7898 int32_t n_frames = 3000 ;
79- std::size_t n_floats =
80- f.tellg () / sizeof (float ); // Number of floats in the audio file.
81- f.seekg (0 , std::ios::beg);
99+
82100 int32_t batch_size = ceil (
83101 n_floats /
84102 (n_bins * n_frames)); // Batch in increments of n_frames, rounding up.
85- std::vector<float > audio_data (batch_size * n_bins * n_frames);
86- f.read (
87- reinterpret_cast <char *>(audio_data.data ()),
88- audio_data.size () * sizeof (float ));
89103
90- ET_LOG (Info, " audio_data len = %d " , audio_data. size () );
104+ ET_LOG (Info, " audio_data len = %zu " , n_floats );
91105
106+ // Create Audio multimodal input
92107 auto audio = std::make_unique<::executorch::extension::llm::Audio>();
93108 audio->batch_size = batch_size;
94109 audio->n_bins = n_bins;
95110 audio->n_frames = n_frames;
96- audio->data .resize (audio_data. size () * sizeof (float ));
97- std::memcpy (
98- audio-> data . data (), audio_data. data (), audio_data. size () * sizeof ( float ) );
111+ audio->data .resize (n_floats * sizeof (float ));
112+ f. read ( reinterpret_cast < char *>(audio-> data . data ()), n_floats * sizeof ( float ));
113+ f. close ( );
99114 return ::executorch::extension::llm::make_audio_input (std::move (*audio));
100115}
101116
102117/* *
103- * @brief Processes audio files for multimodal input
118+ * @brief Loads a .bin file into a tensor and processes it using a .pte
119+ * processor
104120 *
105- * Dispatches audio file processing based on file extension:
106- * - .bin files: Loads preprocessed mel spectrogram features directly
107- * - .wav/.mp3 files: Currently unsupported, throws runtime_error
121+ * This function loads raw audio data from a .bin file (similar to
122+ * loadPreprocessedAudio), creates a tensor from it, and then passes it through
123+ * a processor module loaded from a .pte file to generate processed audio
124+ * features.
108125 *
109- * This function provides a interface for different audio input formats
110- * and can be extended to support raw audio processing in the future.
126+ * @param audio_path Path to the .bin audio file
127+ * @param processor_path Path to the .pte processor file
128+ * @return MultimodalInput containing the processed audio data
129+ * @throws std::runtime_error if file loading or processing fails
130+ */
131+ MultimodalInput processRawAudioFile (
132+ const std::string& audio_path,
133+ const std::string& processor_path) {
134+ if (processor_path.empty ()) {
135+ ET_LOG (Error, " Processor path is required for raw audio processing" );
136+ throw std::runtime_error (
137+ " Processor path is required for raw audio processing" );
138+ }
139+
140+ // Load the audio processor .pte.
141+ std::unique_ptr<Module> processor_module;
142+ try {
143+ processor_module =
144+ std::make_unique<Module>(processor_path, Module::LoadMode::File);
145+ auto load_error = processor_module->load ();
146+ if (load_error != ::executorch::runtime::Error::Ok) {
147+ ET_LOG (
148+ Error,
149+ " Failed to load processor module from: %s" ,
150+ processor_path.c_str ());
151+ throw std::runtime_error (" Failed to load processor module" );
152+ }
153+ } catch (const std::exception& e) {
154+ ET_LOG (Error, " Exception while loading processor module: %s" , e.what ());
155+ throw std::runtime_error (" Exception while loading processor module" );
156+ }
157+
158+ // Load the audio data from file.
159+ std::ifstream f (audio_path, std::ios::binary | std::ios::ate);
160+ if (!f.is_open ()) {
161+ ET_LOG (Error, " Failed to open audio file: %s" , audio_path.c_str ());
162+ throw std::runtime_error (" Failed to open audio file" );
163+ }
164+
165+ std::size_t n_floats = f.tellg () / sizeof (float );
166+ f.seekg (0 , std::ios::beg);
167+
168+ std::vector<float > audio_data (n_floats);
169+ f.read (
170+ reinterpret_cast <char *>(audio_data.data ()),
171+ audio_data.size () * sizeof (float ));
172+ f.close ();
173+
174+ ET_LOG (
175+ Info, " Loaded .bin file: %s, %zu floats" , audio_path.c_str (), n_floats);
176+
177+ // Execute the processor
178+ std::vector<executorch::aten::SizesType> tensor_shape = {
179+ static_cast <executorch::aten::SizesType>(audio_data.size ())};
180+ auto input_tensor = from_blob (
181+ audio_data.data (), tensor_shape, ::executorch::aten::ScalarType::Float);
182+
183+ ET_LOG (Info, " Processing audio through processor module..." );
184+ auto result = processor_module->execute (" forward" , input_tensor);
185+ if (!result.ok ()) {
186+ ET_LOG (Error, " Failed to execute processor's forward method" );
187+ throw std::runtime_error (" Failed to execute processor forward method" );
188+ }
189+
190+ auto outputs = result.get ();
191+ if (outputs.empty ()) {
192+ ET_LOG (Error, " Processor returned no outputs" );
193+ throw std::runtime_error (" Processor returned no outputs" );
194+ }
195+
196+ // Extract processed audio features
197+ const auto & processed_tensor = outputs[0 ].toTensor ();
198+ const float * processed_data = processed_tensor.const_data_ptr <float >();
199+ const auto & sizes = processed_tensor.sizes ();
200+
201+ ET_LOG (
202+ Info,
203+ " Processed audio tensor shape: [%d, %d, %d]" ,
204+ static_cast <int >(sizes[0 ]),
205+ static_cast <int >(sizes[1 ]),
206+ static_cast <int >(sizes[2 ]));
207+
208+ // Create Audio multimodal input from processed features
209+ auto processed_audio =
210+ std::make_unique<::executorch::extension::llm::Audio>();
211+ processed_audio->batch_size =
212+ static_cast <int32_t >(sizes[0 ]); // Note: batching for s > 30 doesn't work
213+ // yet, so this will just be = 1.
214+ processed_audio->n_bins = static_cast <int32_t >(sizes[1 ]);
215+ processed_audio->n_frames =
216+ static_cast <int32_t >(sizes[2 ]); // And this will just be = 3000.
217+
218+ size_t total_elements = processed_audio->batch_size *
219+ processed_audio->n_bins * processed_audio->n_frames ;
220+ processed_audio->data .resize (total_elements * sizeof (float ));
221+ std::memcpy (
222+ processed_audio->data .data (),
223+ processed_data,
224+ total_elements * sizeof (float ));
225+
226+ ET_LOG (
227+ Info,
228+ " Created processed Audio: batch_size=%d, n_bins=%d, n_frames=%d" ,
229+ processed_audio->batch_size ,
230+ processed_audio->n_bins ,
231+ processed_audio->n_frames );
232+
233+ return ::executorch::extension::llm::make_audio_input (
234+ std::move (*processed_audio));
235+ }
236+
237+ /* *
238+ * @brief Processes audio files for multimodal input
239+ *
240+ * Dispatches audio file processing based on file extension and processor
241+ * availability:
242+ * - .bin files with processor: Loads raw audio from .bin and processes through
243+ * processor
244+ * - .bin files without processor: Loads preprocessed mel spectrogram features
245+ * directly
111246 *
112- * @param audio_path Path to the audio file
247+ * @param audio_path Path to the audio file (.bin)
248+ * @param processor_path Path to the processor .pte file (optional)
113249 * @return MultimodalInput containing the processed audio data
114250 * @throws std::runtime_error if file format is unsupported or processing fails
115251 */
116- MultimodalInput processAudioFile (const std::string& audio_path) {
252+ MultimodalInput processAudioFile (
253+ const std::string& audio_path,
254+ const std::string& processor_path = " " ) {
117255 if (ends_with (audio_path, " .bin" )) {
118- // Current behavior - load preprocessed audio stored as a binary file.
119- return loadPreprocessedAudio (audio_path);
120- } else if (ends_with (audio_path, " .wav" ) || ends_with (audio_path, " .mp3" )) {
121- // New: Process raw audio files - unsupported for now
122- ET_LOG (Error, " Raw audio file processing (.wav/.mp3) is not yet supported" );
123- throw std::runtime_error (" Raw audio file processing not supported" );
256+ if (!processor_path.empty ()) {
257+ // Process raw audio from .bin file through the processor
258+ return processRawAudioFile (audio_path, processor_path);
259+ } else {
260+ // Load preprocessed audio stored as a binary file (existing behavior)
261+ return loadPreprocessedAudio (audio_path);
262+ }
124263 } else {
125- ET_LOG (Error, " Unsupported audio file format: %s" , audio_path.c_str ());
264+ ET_LOG (
265+ Error,
266+ " Unsupported audio file format: %s (only .bin files are supported)" ,
267+ audio_path.c_str ());
126268 throw std::runtime_error (" Unsupported audio file format" );
127269 }
128270}
@@ -137,6 +279,7 @@ int32_t main(int32_t argc, char** argv) {
137279 const char * tokenizer_path = FLAGS_tokenizer_path.c_str ();
138280 const char * prompt = FLAGS_prompt.c_str ();
139281 const char * audio_path = FLAGS_audio_path.c_str ();
282+ const char * processor_path = FLAGS_processor_path.c_str ();
140283 float temperature = FLAGS_temperature;
141284 int32_t cpu_threads = FLAGS_cpu_threads;
142285 bool warmup = FLAGS_warmup;
@@ -180,7 +323,7 @@ int32_t main(int32_t argc, char** argv) {
180323 // Prepare inputs
181324 std::vector<MultimodalInput> inputs = {
182325 make_text_input (" <s>[INST][BEGIN_AUDIO]" ),
183- processAudioFile (audio_path),
326+ processAudioFile (audio_path, processor_path ),
184327 make_text_input (std::string (prompt) + " [/INST]" ),
185328 };
186329
0 commit comments