Skip to content

Commit 366c763

Browse files
committed
Make Voxtral work
1 parent c838eee commit 366c763

File tree

5 files changed

+28
-9
lines changed

5 files changed

+28
-9
lines changed

backends/aoti/utils.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
3434
// Convert based on known PyTorch dtype codes (without CUDA-specific
3535
// dependency)
3636
switch (dtype) {
37+
case 4: // PyTorch's int64 dtype code
38+
return executorch::aten::ScalarType::Long;
3739
case 6: // PyTorch's float32 dtype code
3840
return executorch::aten::ScalarType::Float;
3941
case 15: // PyTorch's bfloat16 dtype code

backends/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ executorch_target_link_options_shared_lib(aoti_cuda)
6363

6464
# Add runtime
6565
add_executable(voxtral_runner tests/voxtral_runner.cpp)
66-
target_link_libraries(voxtral_runner PUBLIC aoti_cuda extension_module_static extension_flat_tensor)
66+
target_link_libraries(voxtral_runner PUBLIC aoti_cuda extension_module_static extension_flat_tensor portable_ops_lib)
6767

6868
install(
6969
TARGETS aoti_cuda

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -131,15 +131,20 @@ class CudaBackend final : public ::executorch::runtime::BackendInterface {
131131
// Generate dynamic temporary file path
132132
filesystem::path temp_dir = filesystem::temp_directory_path();
133133
filesystem::path so_path =
134-
temp_dir / ("aoti_cuda_" + to_string(getpid()) + ".so");
134+
temp_dir / (so_blob_key + to_string(getpid()) + ".so");
135135

136136
// Create a temporary file
137137
ofstream outfile(so_path.c_str(), ios::binary);
138138

139139
// Write the ELF buffer to the temporary file
140+
ET_LOG(
141+
Info,
142+
"Writing %zu bytes to %s",
143+
aoti_cuda_buffer->size(),
144+
so_path.c_str());
140145
outfile.write(
141-
(char*)aoti_cuda_buffer->data(),
142-
sizeof(void*) * aoti_cuda_buffer->size());
146+
static_cast<const char*>(aoti_cuda_buffer->data()),
147+
aoti_cuda_buffer->size());
143148

144149
// Finish writing the file to disk
145150
outfile.close();

backends/cuda/runtime/shims/utils.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ namespace cuda {
4040

4141
// Enum for supported data types in et-cuda backend
4242
enum class SupportedDTypes : int32_t {
43+
INT64 = 4, // PyTorch's int64 dtype code
4344
FLOAT32 = 6, // PyTorch's float32 dtype code
4445
BFLOAT16 = 15, // PyTorch's bfloat16 dtype code
4546
};
@@ -100,6 +101,7 @@ using AOTITorchError = Error;
100101
// Helper function to check if a dtype is supported in ET CUDA backend
101102
inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
102103
switch (dtype) {
104+
case static_cast<int32_t>(SupportedDTypes::INT64):
103105
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
104106
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
105107
return true;
@@ -113,8 +115,9 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
113115
ET_CHECK_OR_RETURN_ERROR(
114116
is_dtype_supported_in_et_cuda(dtype),
115117
InvalidArgument,
116-
"Unsupported dtype: %d. Supported dtypes: %d (float32), %d (bfloat16)",
118+
"Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
117119
dtype,
120+
static_cast<int32_t>(SupportedDTypes::INT64),
118121
static_cast<int32_t>(SupportedDTypes::FLOAT32),
119122
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
120123

backends/cuda/tests/voxtral_runner.cpp

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,9 @@ int main(int argc, char** argv) {
136136

137137
const TensorPtr audio_input = create_audio_input();
138138
std::vector<EValue> inputs;
139-
inputs.emplace_back(audio_input);
139+
std::vector<TensorPtr> owned_inputs;
140+
owned_inputs.emplace_back(audio_input);
141+
inputs.emplace_back(*audio_input);
140142

141143
const auto run_start = Clock::now();
142144
Result<std::vector<EValue>> output_result =
@@ -171,7 +173,9 @@ int main(int argc, char** argv) {
171173

172174
const TensorPtr token_ids = create_token_ids_input();
173175
std::vector<EValue> inputs;
174-
inputs.emplace_back(token_ids);
176+
std::vector<TensorPtr> owned_inputs;
177+
owned_inputs.emplace_back(token_ids);
178+
inputs.emplace_back(*token_ids);
175179

176180
const auto run_start = Clock::now();
177181
auto token_output_result = module.execute("token_embedding", inputs);
@@ -203,17 +207,22 @@ int main(int argc, char** argv) {
203207
text_timing.load_ms = load_ms;
204208

205209
std::vector<EValue> inputs;
210+
std::vector<TensorPtr> owned_inputs;
206211
if (token_executed) {
207212
if (token_output.isTensor()) {
208213
inputs.emplace_back(token_output);
209214
}
210215
}
211216

212217
if (inputs.empty()) {
213-
inputs.emplace_back(create_fallback_text_embedding());
218+
auto fallback_embedding = create_fallback_text_embedding();
219+
owned_inputs.emplace_back(fallback_embedding);
220+
inputs.emplace_back(*fallback_embedding);
214221
}
215222

216-
inputs.emplace_back(create_positions_input());
223+
auto positions = create_positions_input();
224+
owned_inputs.emplace_back(positions);
225+
inputs.emplace_back(*positions);
217226

218227
const auto run_start = Clock::now();
219228
Result<std::vector<EValue>> output_result =

0 commit comments

Comments (0)