Skip to content

Commit c8515e1

Browse files
authored
[Web] Replace string with TVMFFIByteArray* to avoid memory issues (#18467)
Passing in a string to `ArrayDecodeStorage` via the packed function definition led to memory issues for larger models (such as `gemma-2-9b-it-q4f32_1-MLC`). Replacing string with TVMFFIByteArray* fixes this issue and also alleviates the stack pollution issue discussed in an earlier PR (#18415). Note that this does not completely fix generation for q0f32 models.
1 parent a8c7580 commit c8515e1

File tree

2 files changed

+37
-42
lines changed

2 files changed

+37
-42
lines changed

web/emcc/wasm_runtime.cc

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -125,24 +125,25 @@ TVM_FFI_STATIC_INIT_BLOCK() {
125125
});
126126
}
127127

128-
void ArrayDecodeStorage(Tensor cpu_arr, std::string bytes, std::string format, std::string dtype) {
128+
void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::string& format,
129+
const std::string& dtype) {
130+
ICHECK_NE(bytes, nullptr);
131+
const char* byte_data = bytes->data;
132+
const size_t byte_size = bytes->size;
129133
if (format == "f32-to-bf16" && dtype == "float32") {
130-
std::vector<uint16_t> buffer(bytes.length() / 2);
131-
std::memcpy(buffer.data(), bytes.data(), buffer.size() * 2);
132-
// decode bf16 to f32
133-
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(buffer.data());
134+
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
134135
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
135136
ICHECK(cpu_arr.IsContiguous());
136137
size_t size = 1;
137138
for (int i = 0; i < cpu_arr->ndim; ++i) {
138139
size *= cpu_arr->shape[i];
139140
}
140-
ICHECK_EQ(size, bytes.length() / 2);
141+
ICHECK_EQ(size, byte_size / 2);
141142
for (size_t i = 0; i < size; ++i) {
142143
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
143144
}
144145
} else {
145-
cpu_arr.CopyFromBytes(bytes.data(), bytes.length());
146+
cpu_arr.CopyFromBytes(byte_data, byte_size);
146147
}
147148
}
148149

@@ -151,16 +152,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
151152
refl::GlobalDef().def_packed(
152153
"tvmjs.array.decode_storage", [](ffi::PackedArgs args, ffi::Any* ret) {
153154
Tensor cpu_arr = args[0].cast<Tensor>();
154-
auto bytes = args[1].cast<ffi::Bytes>();
155+
TVMFFIByteArray* bytes = args[1].cast<TVMFFIByteArray*>();
155156
std::string format = args[2].cast<ffi::String>().operator std::string();
156157
std::string dtype = args[3].cast<ffi::String>().operator std::string();
157158
ArrayDecodeStorage(cpu_arr, bytes, format, dtype);
158-
if (ret != nullptr) {
159-
auto* ret_data = reinterpret_cast<TVMFFIAny*>(ret);
160-
ret_data->type_index = TVMFFITypeIndex::kTVMFFINone;
161-
ret_data->zero_padding = 0;
162-
ret_data->v_int64 = 0;
163-
}
164159
});
165160
}
166161

web/package-lock.json

Lines changed: 28 additions & 28 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)