diff --git a/tflite/core/interpreter_builder.cc b/tflite/core/interpreter_builder.cc index 8b413cb2a..b8920357f 100644 --- a/tflite/core/interpreter_builder.cc +++ b/tflite/core/interpreter_builder.cc @@ -725,7 +725,8 @@ TfLiteStatus InterpreterBuilder::ParseTensors( if (subgraph->SetTensorParametersReadOnly( i, type, get_name(tensor), dims, quantization, buffer_ptr, buffer_size, allocation_, sparsity, - /*buffer_identifier=*/tensor->buffer()) != kTfLiteOk) { + /*buffer_identifier=*/tensor->buffer(), + /*external_buffer_id=*/tensor->external_buffer()) != kTfLiteOk) { TF_LITE_REPORT_ERROR(error_reporter_, "Tensor %d is invalidly specified in schema.\n", i); diff --git a/tflite/core/subgraph.cc b/tflite/core/subgraph.cc index eff761585..32273ea61 100644 --- a/tflite/core/subgraph.cc +++ b/tflite/core/subgraph.cc @@ -1917,7 +1917,7 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const size_t ndims, const int* dims, TfLiteQuantization quantization, const char* buffer, size_t bytes, const Allocation* allocation, TfLiteSparsity* sparsity, - const size_t buffer_identifier) { + const size_t buffer_identifier, const size_t external_buffer_id) { // Ensure quantization cleanup on failure. ScopedTfLiteQuantization scoped_quantization(&quantization); ScopedTfLiteSparsity scoped_sparsity(sparsity); @@ -1968,6 +1968,10 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( if (buffer_identifier != kTfLiteNoBufferIdentifier) { tensor_buffer_identifiers_[tensor_index] = buffer_identifier; } + if (external_buffer_id != kTfLiteNoBufferIdentifier && + external_buffer_id != 0) { + tensor_external_buffer_ids_[tensor_index] = external_buffer_id; + } return kTfLiteOk; } diff --git a/tflite/core/subgraph.h b/tflite/core/subgraph.h index 529a6ffd9..0beafc683 100644 --- a/tflite/core/subgraph.h +++ b/tflite/core/subgraph.h @@ -130,22 +130,32 @@ class Subgraph { // This variant assumes an external buffer has been allocated of size // bytes. The lifetime of buffer must be ensured to be greater or equal // to Interpreter. `quantization` ownership is passed to the subgraph. + // `buffer_identifier`: An optional value to identify the buffer. If set to + // a value other than kTfLiteNoBufferIdentifier, this tensor is considered a + // constant tensor shared across multiple subgraphs / interpreters. + // `external_buffer_id`: An optional value to identify the external buffer. If + // set to a value other than kTfLiteNoBufferIdentifier, this tensor is + // considered a tensor using an external buffer shared across multiple + // subgraphs / interpreters. inline TfLiteStatus SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantization quantization, const char* buffer, size_t bytes, const Allocation* allocation = nullptr, TfLiteSparsity* sparsity = nullptr, - size_t buffer_identifier = kTfLiteNoBufferIdentifier) { + size_t buffer_identifier = kTfLiteNoBufferIdentifier, + size_t external_buffer_id = kTfLiteNoBufferIdentifier) { return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(), dims.data(), quantization, buffer, bytes, - allocation, sparsity, buffer_identifier); + allocation, sparsity, buffer_identifier, + external_buffer_id); } TfLiteStatus SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, size_t ndims, const int* dims, TfLiteQuantization quantization, const char* buffer, size_t bytes, const Allocation* allocation = nullptr, TfLiteSparsity* sparsity = nullptr, - size_t buffer_identifier = kTfLiteNoBufferIdentifier); + size_t buffer_identifier = kTfLiteNoBufferIdentifier, + size_t external_buffer_id = kTfLiteNoBufferIdentifier); // Set description of inputs/outputs/data/fptrs for node `node_index`. // This variant assumes an external buffer has been allocated of size @@ -611,6 +621,11 @@ class Subgraph { return tensor_buffer_identifiers_; } + const std::unordered_map& GetExternalTensorBufferIdentifiers() + const { + return tensor_external_buffer_ids_; + } + // Replaces the node for the given execution index with the subgraph. // // - The node and subgraph tensor counts must match. @@ -1220,6 +1235,10 @@ class Subgraph { // Maps tensor constant buffers used in the subgraph to a model-wide // identifiers. std::unordered_map tensor_buffer_identifiers_; + + // Maps tensor external buffer ids used in the subgraph to a model-wide + // identifiers. + std::unordered_map tensor_external_buffer_ids_; }; } // namespace tflite