
Commit 0481955

fix: Validate input type to make sure a TypedArray is passed
1 parent cbb76b9 commit 0481955

File tree

3 files changed (+37, -18 lines changed)
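
For context, here is a minimal TypeScript sketch of the JavaScript-side contract this commit tightens: each input value must be a TypedArray whose element type and size match the corresponding input tensor. The package import path, asset path, `run(...)` call, and input shape below are illustrative assumptions, not taken from this diff.

// Sketch only: import path, asset path, `run(...)` call, and shape are assumptions.
import { loadTensorflowModel } from 'react-native-fast-tflite'

async function runExample(): Promise<void> {
  // Hypothetical model with a single float32 input tensor of shape [1, 224, 224, 3]
  const model = await loadTensorflowModel(require('./assets/model.tflite'))

  // OK: a TypedArray whose element type and byte length match the input tensor
  const input = new Float32Array(1 * 224 * 224 * 3)
  const outputs = await model.run([input])
  console.log(outputs)

  // Rejected by the new DEBUG check: a plain JS array is not a TypedArray
  // await model.run([[0, 0, 0]])
}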

cpp/TensorHelpers.cpp

Lines changed: 11 additions & 6 deletions
@@ -110,7 +110,8 @@ size_t TensorHelpers::getTFLTensorDataTypeSize(TfLiteType dataType) {
       return sizeof(uint16_t);
     default:
       [[unlikely]];
-      throw std::runtime_error("Unsupported output data type! " + dataTypeToString(dataType));
+      throw std::runtime_error("TFLite: Unsupported output data type! " +
+                               dataTypeToString(dataType));
   }
 }
 
@@ -152,7 +153,8 @@ TypedArrayBase TensorHelpers::createJSBufferForTensor(jsi::Runtime& runtime,
       return TypedArray<TypedArrayKind::Uint32Array>(runtime, size);
     default:
       [[unlikely]];
-      throw std::runtime_error("Unsupported tensor data type! " + dataTypeToString(dataType));
+      throw std::runtime_error("TFLite: Unsupported tensor data type! " +
+                               dataTypeToString(dataType));
   }
 }
 
@@ -164,7 +166,7 @@ void TensorHelpers::updateJSBufferFromTensor(jsi::Runtime& runtime, TypedArrayBa
   void* data = TfLiteTensorData(tensor);
   if (data == nullptr) {
     [[unlikely]];
-    throw std::runtime_error("Failed to get data from tensor \"" + name + "\"!");
+    throw std::runtime_error("TFLite: Failed to get data from tensor \"" + name + "\"!");
   }
 
   // count of bytes, may be larger than count of numbers (e.g. for float32)
@@ -213,19 +215,21 @@ void TensorHelpers::updateJSBufferFromTensor(jsi::Runtime& runtime, TypedArrayBa
       break;
     default:
       [[unlikely]];
-      throw jsi::JSError(runtime, "Unsupported output data type! " + dataTypeToString(dataType));
+      throw jsi::JSError(runtime,
+                         "TFLite: Unsupported output data type! " + dataTypeToString(dataType));
   }
 }
 
 void TensorHelpers::updateTensorFromJSBuffer(jsi::Runtime& runtime, TfLiteTensor* tensor,
                                              TypedArrayBase& jsBuffer) {
 #if DEBUG
+  // Validate data-type
   TypedArrayKind kind = jsBuffer.getKind(runtime);
   TfLiteType receivedType = getTFLDataTypeForTypedArrayKind(kind);
   TfLiteType expectedType = TfLiteTensorType(tensor);
   if (receivedType != expectedType) {
     [[unlikely]];
-    throw std::runtime_error("Invalid input type! Model expected " +
+    throw std::runtime_error("TFLite: Invalid input type! Model expected " +
                              dataTypeToString(expectedType) + ", but received " +
                              dataTypeToString(receivedType) + "!");
   }
@@ -235,11 +239,12 @@ void TensorHelpers::updateTensorFromJSBuffer(jsi::Runtime& runtime, TfLiteTensor
   jsi::ArrayBuffer buffer = jsBuffer.getBuffer(runtime);
 
 #if DEBUG
+  // Validate size
   int inputBufferSize = buffer.size(runtime);
   int tensorSize = getTensorTotalLength(tensor) * getTFLTensorDataTypeSize(tensor->type);
   if (tensorSize != inputBufferSize) {
     [[unlikely]];
-    throw std::runtime_error("Input Buffer size (" + std::to_string(inputBufferSize) +
+    throw std::runtime_error("TFLite: Input Buffer size (" + std::to_string(inputBufferSize) +
                              ") does not match the Input Tensor's expected size (" +
                              std::to_string(tensorSize) +
                              ")! Make sure to resize the input values accordingly.");

cpp/TensorflowPlugin.cpp

Lines changed: 25 additions & 11 deletions
@@ -171,8 +171,9 @@ TensorflowPlugin::TensorflowPlugin(TfLiteInterpreter* interpreter, Buffer model,
   TfLiteStatus status = TfLiteInterpreterAllocateTensors(_interpreter);
   if (status != kTfLiteOk) {
     [[unlikely]];
-    throw std::runtime_error("Failed to allocate memory for input/output tensors! Status: " +
-                             tfLiteStatusToString(status));
+    throw std::runtime_error(
+        "TFLite: Failed to allocate memory for input/output tensors! Status: " +
+        tfLiteStatusToString(status));
   }
 
   log("Successfully created Tensorflow Plugin!");
@@ -205,23 +206,33 @@ void TensorflowPlugin::copyInputBuffers(jsi::Runtime& runtime, jsi::Object input
 #if DEBUG
   if (!inputValues.isArray(runtime)) {
     [[unlikely]];
-    throw std::runtime_error(
-        "TFLite: Input Values must be an array, one item for each input tensor!");
+    throw jsi::JSError(runtime,
+                       "TFLite: Input Values must be an array, one item for each input tensor!");
   }
 #endif
 
   jsi::Array array = inputValues.asArray(runtime);
   size_t count = array.size(runtime);
   if (count != TfLiteInterpreterGetInputTensorCount(_interpreter)) {
     [[unlikely]];
-    throw std::runtime_error(
-        "TFLite: Input Values have different size than there are input tensors!");
+    throw jsi::JSError(runtime,
+                       "TFLite: Input Values have different size than there are input tensors!");
   }
 
   for (size_t i = 0; i < count; i++) {
     TfLiteTensor* tensor = TfLiteInterpreterGetInputTensor(_interpreter, i);
-    jsi::Value value = array.getValueAtIndex(runtime, i);
-    TypedArrayBase inputBuffer = getTypedArray(runtime, value.asObject(runtime));
+    jsi::Object object = array.getValueAtIndex(runtime, i).asObject(runtime);
+
+#if DEBUG
+    if (!isTypedArray(runtime, object)) {
+      [[unlikely]];
+      throw jsi::JSError(
+          runtime,
+          "TFLite: Input value is not a TypedArray! (Uint8Array, Uint16Array, Float32Array, etc.)");
+    }
+#endif
+
+    TypedArrayBase inputBuffer = getTypedArray(runtime, std::move(object));
     TensorHelpers::updateTensorFromJSBuffer(runtime, tensor, inputBuffer);
   }
 }
@@ -244,7 +255,8 @@ void TensorflowPlugin::run() {
   TfLiteStatus status = TfLiteInterpreterInvoke(_interpreter);
   if (status != kTfLiteOk) {
     [[unlikely]];
-    throw std::runtime_error("Failed to run TFLite Model! Status: " + tfLiteStatusToString(status));
+    throw std::runtime_error("TFLite: Failed to run TFLite Model! Status: " +
+                             tfLiteStatusToString(status));
   }
 }
 
@@ -296,7 +308,8 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
     TfLiteTensor* tensor = TfLiteInterpreterGetInputTensor(_interpreter, i);
     if (tensor == nullptr) {
       [[unlikely]];
-      throw jsi::JSError(runtime, "Failed to get input tensor " + std::to_string(i) + "!");
+      throw jsi::JSError(runtime,
+                         "TFLite: Failed to get input tensor " + std::to_string(i) + "!");
     }
 
     jsi::Object object = TensorHelpers::tensorToJSObject(runtime, tensor);
@@ -310,7 +323,8 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
     const TfLiteTensor* tensor = TfLiteInterpreterGetOutputTensor(_interpreter, i);
     if (tensor == nullptr) {
       [[unlikely]];
-      throw jsi::JSError(runtime, "Failed to get output tensor " + std::to_string(i) + "!");
+      throw jsi::JSError(runtime,
+                         "TFLite: Failed to get output tensor " + std::to_string(i) + "!");
     }
 
     jsi::Object object = TensorHelpers::tensorToJSObject(runtime, tensor);
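
Because the input validation above now throws jsi::JSError rather than std::runtime_error, the failure surfaces as a regular JavaScript exception that callers can catch. Below is a hedged sketch; the package import path, asset path, and `model.run(inputs)` call are assumptions, while the error message matches the one added in this diff.

// Sketch only: import path, asset path, and `run(...)` call are assumptions.
import { loadTensorflowModel } from 'react-native-fast-tflite'

async function demoInvalidInput(): Promise<void> {
  const model = await loadTensorflowModel(require('./assets/model.tflite'))
  try {
    // A plain JS array is not a TypedArray, so DEBUG builds now reject it up front.
    await model.run([[1, 2, 3] as unknown as Float32Array])
  } catch (e) {
    // e.g. "TFLite: Input value is not a TypedArray! (Uint8Array, Uint16Array, Float32Array, etc.)"
    console.warn('TFLite rejected the input:', (e as Error).message)
  }
}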

src/TensorflowLite.ts

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ export function loadTensorflowModel(
     uri = source.url
   } else {
     throw new Error(
-      'Invalid source passed! Source should be either a React Native require(..) or a `{ url: string }` object!'
+      'TFLite: Invalid source passed! Source should be either a React Native require(..) or a `{ url: string }` object!'
     )
   }
   return global.__loadTensorflowModel(uri, delegate)
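
The updated error message above spells out the two accepted source shapes. A short usage sketch follows; the package import path, asset path, and URL are placeholders, not part of this commit.

// Sketch only: import path, asset path, and URL are placeholders.
import { loadTensorflowModel } from 'react-native-fast-tflite'

async function loadExamples(): Promise<void> {
  // 1. A React Native require(..) of a bundled model asset
  const bundled = await loadTensorflowModel(require('./assets/model.tflite'))

  // 2. A `{ url: string }` object pointing at a model file
  const remote = await loadTensorflowModel({ url: 'https://example.com/model.tflite' })

  // Anything else throws:
  // "TFLite: Invalid source passed! Source should be either a React Native require(..)
  //  or a `{ url: string }` object!"
  console.log(bundled, remote)
}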
