Skip to content

Commit a2b997a

Browse files
committed
Merge remote-tracking branch 'origin/main' into jni-layer-llama-1
2 parents d6b957d + b759ae8 commit a2b997a

File tree

5 files changed

+584
-148
lines changed

5 files changed

+584
-148
lines changed

backends/arm/operators/op_view.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,13 @@ def define_node(
4444
validate_valid_dtype(
4545
self.target,
4646
[inputs[0], output],
47-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
47+
[
48+
ts.DType.INT8,
49+
ts.DType.INT16,
50+
ts.DType.INT32,
51+
ts.DType.FP32,
52+
ts.DType.BOOL,
53+
],
4854
output.tosa_spec,
4955
)
5056

backends/arm/test/ops/test_linear.py

Lines changed: 68 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,12 @@
99
from typing import Tuple
1010

1111
import pytest
12-
1312
import torch
14-
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.test import common, conftest
1518

1619
from executorch.backends.arm.test.tester.test_pipeline import (
1720
EthosU55PipelineINT,
@@ -20,6 +23,8 @@
2023
TosaPipelineINT,
2124
VgfPipeline,
2225
)
26+
from executorch.backends.arm.tosa_specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2328

2429
aten_op = "torch.ops.aten.linear.default"
2530

@@ -143,7 +148,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
143148
pipeline.run()
144149

145150

146-
@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness.
147151
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
148152
def test_linear_tosa_INT(test_data: torch.Tensor):
149153
test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -243,3 +247,64 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
243247
per_channel_quantization=per_channel_quantization,
244248
)
245249
pipeline.run()
250+
251+
252+
def get_symmetric_a16w8_linear_quantizer(
253+
u55_config=False, per_channel_quantization=False
254+
):
255+
tosa_version = conftest.get_option("tosa_version")
256+
tosa_profiles = {
257+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
258+
}
259+
260+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
261+
quantizer.set_global(
262+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
263+
)
264+
quantizer.set_module_type(
265+
torch.nn.Linear,
266+
get_symmetric_a16w8_quantization_config(
267+
is_per_channel=per_channel_quantization
268+
),
269+
)
270+
271+
return Quantize(
272+
quantizer,
273+
get_symmetric_a16w8_quantization_config(
274+
is_per_channel=per_channel_quantization
275+
),
276+
)
277+
278+
279+
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
280+
@pytest.mark.xfail(
281+
reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph"
282+
)
283+
def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
284+
"""Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
285+
test_data, out_features, has_bias, per_channel_quantization = test_data()
286+
in_features = test_data.shape[-1]
287+
288+
# Create pipeline with custom 16A8W quantization config
289+
pipeline = TosaPipelineINT[input_t1](
290+
Linear(
291+
in_features=in_features,
292+
out_features=out_features,
293+
bias=has_bias,
294+
),
295+
(test_data,),
296+
aten_op,
297+
exir_op=[],
298+
per_channel_quantization=per_channel_quantization,
299+
use_to_edge_transform_and_lower=True,
300+
tosa_extensions=["int16"],
301+
)
302+
303+
pipeline.change_args(
304+
"quantize",
305+
get_symmetric_a16w8_linear_quantizer(
306+
per_channel_quantization=per_channel_quantization
307+
),
308+
)
309+
# Run the pipeline
310+
pipeline.run()

examples/models/voxtral/multimodal.cpp

Lines changed: 174 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,10 @@
1212

1313
#include <gflags/gflags.h>
1414

15+
#include <executorch/extension/module/module.h>
16+
#include <executorch/extension/tensor/tensor_ptr_maker.h>
17+
#include <executorch/runtime/core/evalue.h>
18+
1519
#include <executorch/extension/llm/runner/audio.h>
1620
#include <executorch/extension/llm/runner/image.h>
1721
#include <executorch/extension/llm/runner/llm_runner_helper.h>
@@ -36,6 +40,11 @@ DEFINE_string(prompt, "What is happening in this audio?", "Text prompt.");
3640

3741
DEFINE_string(audio_path, "", "Path to input audio file.");
3842

43+
DEFINE_string(
44+
processor_path,
45+
"",
46+
"Path to processor .pte file for raw audio processing.");
47+
3948
DEFINE_double(
4049
temperature,
4150
0.8f,
@@ -50,16 +59,48 @@ DEFINE_bool(warmup, false, "Whether to run a warmup run.");
5059

5160
namespace {
5261

62+
using ::executorch::extension::from_blob;
63+
using ::executorch::extension::Module;
5364
using ::executorch::extension::llm::Image;
5465
using ::executorch::extension::llm::make_image_input;
5566
using ::executorch::extension::llm::make_text_input;
5667
using ::executorch::extension::llm::MultimodalInput;
68+
using ::executorch::runtime::EValue;
5769

5870
bool ends_with(const std::string& str, const std::string& suffix) {
5971
return str.size() >= suffix.size() &&
6072
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
6173
}
6274

75+
/**
76+
* @brief Loads float data from a binary file
77+
*
78+
* @param audio_path Path to the binary audio file (.bin)
79+
* @return Vector of float data loaded from the file
80+
* @throws std::runtime_error if file loading fails
81+
*/
82+
std::vector<float> loadBinaryFloatData(const std::string& audio_path) {
83+
std::ifstream f(audio_path, std::ios::binary | std::ios::ate);
84+
if (!f.is_open()) {
85+
ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str());
86+
throw std::runtime_error("Failed to open audio file");
87+
}
88+
89+
std::size_t n_floats =
90+
f.tellg() / sizeof(float); // Number of floats in the audio file
91+
f.seekg(0, std::ios::beg);
92+
93+
std::vector<float> audio_data(n_floats);
94+
f.read(
95+
reinterpret_cast<char*>(audio_data.data()),
96+
audio_data.size() * sizeof(float));
97+
f.close();
98+
99+
ET_LOG(
100+
Info, "Loaded .bin file: %s, %zu floats", audio_path.c_str(), n_floats);
101+
return audio_data;
102+
}
103+
63104
/**
64105
* @brief Loads preprocessed audio data from a binary file
65106
*
@@ -73,22 +114,19 @@ bool ends_with(const std::string& str, const std::string& suffix) {
73114
* @return MultimodalInput containing the loaded audio data
74115
*/
75116
MultimodalInput loadPreprocessedAudio(const std::string& audio_path) {
76-
std::ifstream f(audio_path, std::ios::binary | std::ios::ate);
117+
std::vector<float> audio_data = loadBinaryFloatData(audio_path);
118+
77119
int32_t n_bins = 128;
78120
int32_t n_frames = 3000;
79-
std::size_t n_floats =
80-
f.tellg() / sizeof(float); // Number of floats in the audio file.
81-
f.seekg(0, std::ios::beg);
121+
122+
std::size_t n_floats = audio_data.size();
82123
int32_t batch_size = ceil(
83124
n_floats /
84125
(n_bins * n_frames)); // Batch in increments of n_frames, rounding up.
85-
std::vector<float> audio_data(batch_size * n_bins * n_frames);
86-
f.read(
87-
reinterpret_cast<char*>(audio_data.data()),
88-
audio_data.size() * sizeof(float));
89126

90127
ET_LOG(Info, "audio_data len = %d", audio_data.size());
91128

129+
// Create Audio multimodal input
92130
auto audio = std::make_unique<::executorch::extension::llm::Audio>();
93131
audio->batch_size = batch_size;
94132
audio->n_bins = n_bins;
@@ -100,29 +138,140 @@ MultimodalInput loadPreprocessedAudio(const std::string& audio_path) {
100138
}
101139

102140
/**
103-
* @brief Processes audio files for multimodal input
141+
* @brief Loads a .bin file into a tensor and processes it using a .pte
142+
* processor
104143
*
105-
* Dispatches audio file processing based on file extension:
106-
* - .bin files: Loads preprocessed mel spectrogram features directly
107-
* - .wav/.mp3 files: Currently unsupported, throws runtime_error
144+
* This function loads raw audio data from a .bin file (similar to
145+
* loadPreprocessedAudio), creates a tensor from it, and then passes it through
146+
* a processor module loaded from a .pte file to generate processed audio
147+
* features.
148+
*
149+
* @param audio_path Path to the .bin audio file
150+
* @param processor_path Path to the .pte processor file
151+
* @return MultimodalInput containing the processed audio data
152+
* @throws std::runtime_error if file loading or processing fails
153+
*/
154+
MultimodalInput processRawAudioFile(
155+
const std::string& audio_path,
156+
const std::string& processor_path) {
157+
if (processor_path.empty()) {
158+
ET_LOG(Error, "Processor path is required for raw audio processing");
159+
throw std::runtime_error(
160+
"Processor path is required for raw audio processing");
161+
}
162+
163+
// Load the audio processor .pte.
164+
std::unique_ptr<Module> processor_module;
165+
try {
166+
processor_module =
167+
std::make_unique<Module>(processor_path, Module::LoadMode::File);
168+
auto load_error = processor_module->load();
169+
if (load_error != ::executorch::runtime::Error::Ok) {
170+
ET_LOG(
171+
Error,
172+
"Failed to load processor module from: %s",
173+
processor_path.c_str());
174+
throw std::runtime_error("Failed to load processor module");
175+
}
176+
} catch (const std::exception& e) {
177+
ET_LOG(Error, "Exception while loading processor module: %s", e.what());
178+
throw std::runtime_error("Exception while loading processor module");
179+
}
180+
181+
// Load the audio data from file.
182+
std::vector<float> audio_data = loadBinaryFloatData(audio_path);
183+
184+
// Execute the processor
185+
std::vector<executorch::aten::SizesType> tensor_shape = {
186+
static_cast<executorch::aten::SizesType>(audio_data.size())};
187+
auto input_tensor = from_blob(
188+
audio_data.data(), tensor_shape, ::executorch::aten::ScalarType::Float);
189+
190+
ET_LOG(Info, "Processing audio through processor module...");
191+
auto result = processor_module->execute("forward", input_tensor);
192+
if (!result.ok()) {
193+
ET_LOG(Error, "Failed to execute processor's forward method");
194+
throw std::runtime_error("Failed to execute processor forward method");
195+
}
196+
197+
auto outputs = result.get();
198+
if (outputs.empty()) {
199+
ET_LOG(Error, "Processor returned no outputs");
200+
throw std::runtime_error("Processor returned no outputs");
201+
}
202+
203+
// Extract processed audio features
204+
const auto& processed_tensor = outputs[0].toTensor();
205+
const float* processed_data = processed_tensor.const_data_ptr<float>();
206+
const auto& sizes = processed_tensor.sizes();
207+
208+
ET_LOG(
209+
Info,
210+
"Processed audio tensor shape: [%d, %d, %d]",
211+
static_cast<int>(sizes[0]),
212+
static_cast<int>(sizes[1]),
213+
static_cast<int>(sizes[2]));
214+
215+
// Create Audio multimodal input from processed features
216+
auto processed_audio =
217+
std::make_unique<::executorch::extension::llm::Audio>();
218+
processed_audio->batch_size =
219+
static_cast<int32_t>(sizes[0]); // Note: batching for s > 30 doesn't work
220+
// yet, so this will just be = 1.
221+
processed_audio->n_bins = static_cast<int32_t>(sizes[1]);
222+
processed_audio->n_frames =
223+
static_cast<int32_t>(sizes[2]); // And this will just be = 3000.
224+
225+
size_t total_elements = processed_audio->batch_size *
226+
processed_audio->n_bins * processed_audio->n_frames;
227+
processed_audio->data.resize(total_elements * sizeof(float));
228+
std::memcpy(
229+
processed_audio->data.data(),
230+
processed_data,
231+
total_elements * sizeof(float));
232+
233+
ET_LOG(
234+
Info,
235+
"Created processed Audio: batch_size=%d, n_bins=%d, n_frames=%d",
236+
processed_audio->batch_size,
237+
processed_audio->n_bins,
238+
processed_audio->n_frames);
239+
240+
return ::executorch::extension::llm::make_audio_input(
241+
std::move(*processed_audio));
242+
}
243+
244+
/**
245+
* @brief Processes audio files for multimodal input
108246
*
109-
* This function provides a interface for different audio input formats
110-
* and can be extended to support raw audio processing in the future.
247+
* Dispatches audio file processing based on file extension and processor
248+
* availability:
249+
* - .bin files with processor: Loads raw audio from .bin and processes through
250+
* processor
251+
* - .bin files without processor: Loads preprocessed mel spectrogram features
252+
* directly
111253
*
112-
* @param audio_path Path to the audio file
254+
* @param audio_path Path to the audio file (.bin)
255+
* @param processor_path Path to the processor .pte file (optional)
113256
* @return MultimodalInput containing the processed audio data
114257
* @throws std::runtime_error if file format is unsupported or processing fails
115258
*/
116-
MultimodalInput processAudioFile(const std::string& audio_path) {
259+
MultimodalInput processAudioFile(
260+
const std::string& audio_path,
261+
const std::string& processor_path = "") {
117262
if (ends_with(audio_path, ".bin")) {
118-
// Current behavior - load preprocessed audio stored as a binary file.
119-
return loadPreprocessedAudio(audio_path);
120-
} else if (ends_with(audio_path, ".wav") || ends_with(audio_path, ".mp3")) {
121-
// New: Process raw audio files - unsupported for now
122-
ET_LOG(Error, "Raw audio file processing (.wav/.mp3) is not yet supported");
123-
throw std::runtime_error("Raw audio file processing not supported");
263+
if (!processor_path.empty()) {
264+
// Process raw audio from .bin file through the processor
265+
return processRawAudioFile(audio_path, processor_path);
266+
} else {
267+
// Load preprocessed audio stored as a binary file (existing behavior)
268+
return loadPreprocessedAudio(audio_path);
269+
}
124270
} else {
125-
ET_LOG(Error, "Unsupported audio file format: %s", audio_path.c_str());
271+
ET_LOG(
272+
Error,
273+
"Unsupported audio file format: %s (only .bin files are supported)",
274+
audio_path.c_str());
126275
throw std::runtime_error("Unsupported audio file format");
127276
}
128277
}
@@ -137,6 +286,7 @@ int32_t main(int32_t argc, char** argv) {
137286
const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
138287
const char* prompt = FLAGS_prompt.c_str();
139288
const char* audio_path = FLAGS_audio_path.c_str();
289+
const char* processor_path = FLAGS_processor_path.c_str();
140290
float temperature = FLAGS_temperature;
141291
int32_t cpu_threads = FLAGS_cpu_threads;
142292
bool warmup = FLAGS_warmup;
@@ -184,7 +334,7 @@ int32_t main(int32_t argc, char** argv) {
184334
inputs.emplace_back(make_text_input("<s>[INST][BEGIN_AUDIO]"));
185335

186336
// 2. Add audio input
187-
inputs.emplace_back(processAudioFile(audio_path));
337+
inputs.emplace_back(processAudioFile(audio_path, processor_path));
188338

189339
// 3. Add text input (the actual user-submitted prompt)
190340
inputs.emplace_back(make_text_input(std::string(prompt) + "[/INST]"));

0 commit comments

Comments
 (0)