Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tflite/core/interpreter_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,8 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
if (subgraph->SetTensorParametersReadOnly(
i, type, get_name(tensor), dims, quantization, buffer_ptr,
buffer_size, allocation_, sparsity,
/*buffer_identifier=*/tensor->buffer()) != kTfLiteOk) {
/*buffer_identifier=*/tensor->buffer(),
/*external_buffer_id=*/tensor->external_buffer()) != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Tensor %d is invalidly specified in schema.\n",
i);
Expand Down
6 changes: 5 additions & 1 deletion tflite/core/subgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1917,7 +1917,7 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name, const size_t ndims,
const int* dims, TfLiteQuantization quantization, const char* buffer,
size_t bytes, const Allocation* allocation, TfLiteSparsity* sparsity,
const size_t buffer_identifier) {
const size_t buffer_identifier, const size_t external_buffer_id) {
// Ensure quantization cleanup on failure.
ScopedTfLiteQuantization scoped_quantization(&quantization);
ScopedTfLiteSparsity scoped_sparsity(sparsity);
Expand Down Expand Up @@ -1968,6 +1968,10 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
if (buffer_identifier != kTfLiteNoBufferIdentifier) {
tensor_buffer_identifiers_[tensor_index] = buffer_identifier;
}
if (external_buffer_id != kTfLiteNoBufferIdentifier &&
external_buffer_id != 0) {
tensor_external_buffer_ids_[tensor_index] = external_buffer_id;
}
return kTfLiteOk;
}

Expand Down
25 changes: 22 additions & 3 deletions tflite/core/subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,32 @@ class Subgraph {
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter. `quantization` ownership is passed to the subgraph.
// `buffer_identifier`: An optional value to identify the buffer. If set to
// a value other than kTfLiteNoBufferIdentifier, this tensor is considered a
// constant tensor shared across multiple subgraphs / interpreters.
// `external_buffer_id`: An optional value to identify the external buffer. If
// set to a value other than kTfLiteNoBufferIdentifier, this tensor is
// considered a tensor using an external buffer shared across multiple
// subgraphs / interpreters.
inline TfLiteStatus SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantization quantization,
const char* buffer, size_t bytes, const Allocation* allocation = nullptr,
TfLiteSparsity* sparsity = nullptr,
size_t buffer_identifier = kTfLiteNoBufferIdentifier) {
size_t buffer_identifier = kTfLiteNoBufferIdentifier,
size_t external_buffer_id = kTfLiteNoBufferIdentifier) {
return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
dims.data(), quantization, buffer, bytes,
allocation, sparsity, buffer_identifier);
allocation, sparsity, buffer_identifier,
external_buffer_id);
}
TfLiteStatus SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name, size_t ndims,
const int* dims, TfLiteQuantization quantization, const char* buffer,
size_t bytes, const Allocation* allocation = nullptr,
TfLiteSparsity* sparsity = nullptr,
size_t buffer_identifier = kTfLiteNoBufferIdentifier);
size_t buffer_identifier = kTfLiteNoBufferIdentifier,
size_t external_buffer_id = kTfLiteNoBufferIdentifier);

// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
Expand Down Expand Up @@ -611,6 +621,11 @@ class Subgraph {
return tensor_buffer_identifiers_;
}

const std::unordered_map<size_t, size_t>& GetExternalTensorBufferIdentifiers()
const {
return tensor_external_buffer_ids_;
}

// Replaces the node for the given execution index with the subgraph.
//
// - The node and subgraph tensor counts must match.
Expand Down Expand Up @@ -1220,6 +1235,10 @@ class Subgraph {
// Maps tensor constant buffers used in the subgraph to a model-wide
// identifiers.
std::unordered_map<size_t, size_t> tensor_buffer_identifiers_;

// Maps tensor external buffer ids used in the subgraph to a model-wide
// identifiers.
std::unordered_map<size_t, size_t> tensor_external_buffer_ids_;
};

} // namespace tflite
Expand Down
Loading