Skip to content

Commit 3792211

Browse files
ai-edge-botcopybara-github
authored andcommitted
Clear the cache as shape propagation is unpredictable, and make sure the size is updated through tensor buffer requirements
And enhance the testing to run full inference LiteRT-PiperOrigin-RevId: 825765979
1 parent 9e9aab6 commit 3792211

File tree

3 files changed

+41
-20
lines changed

3 files changed

+41
-20
lines changed

litert/cc/litert_compiled_model.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,11 @@ Expected<TensorBuffer> CompiledModel::CreateInputOutputBuffer(
9494
is_input ? model_.GetInputTensorType(signature_index, tensor_name)
9595
: model_.GetOutputTensorType(signature_index, tensor_name);
9696
LITERT_ASSIGN_OR_RETURN(RankedTensorType tensor_type, tensor_type_expected);
97-
Expected<TensorBufferRequirements> buffer_requirements_expected =
98-
is_input ? GetInputBufferRequirements(signature_index, tensor_name)
99-
: GetOutputBufferRequirements(signature_index, tensor_name);
100-
101-
LITERT_ASSIGN_OR_RETURN(const TensorBufferRequirements& buffer_requirements,
102-
buffer_requirements_expected);
97+
LITERT_ASSIGN_OR_RETURN(auto env, GetEnvironment());
10398
if (is_input) {
99+
LITERT_ASSIGN_OR_RETURN(
100+
TensorBufferRequirements buffer_requirements,
101+
GetInputBufferRequirements(signature_index, tensor_name));
104102
LITERT_ASSIGN_OR_RETURN(size_t tensor_index,
105103
FindInputIndex(signature_index, tensor_name));
106104
LiteRtLayout input_layout;
@@ -111,20 +109,28 @@ Expected<TensorBuffer> CompiledModel::CreateInputOutputBuffer(
111109
tensor_type = RankedTensorType(tensor_type.ElementType(),
112110
std::move(runtime_layout));
113111
}
112+
return CreateBufferImpl(env, buffer_requirements, tensor_type);
114113
} else {
115114
const auto& dims = tensor_type.Layout().Dimensions();
116115
if (absl::c_find(dims, -1) != dims.end()) {
117116
LITERT_ASSIGN_OR_RETURN(size_t tensor_index,
118117
FindOutputIndex(signature_index, tensor_name));
119118
LITERT_ASSIGN_OR_RETURN(
120-
std::vector<Layout> output_layouts,
119+
std::vector<Layout> runtime_layouts,
121120
GetOutputTensorLayouts(signature_index, /*update_allocation=*/true));
122121
tensor_type = RankedTensorType(tensor_type.ElementType(),
123-
std::move(output_layouts[tensor_index]));
122+
std::move(runtime_layouts[tensor_index]));
123+
LITERT_ASSIGN_OR_RETURN(
124+
const TensorBufferRequirements& refreshed_requirements,
125+
GetOutputBufferRequirements(signature_index, tensor_name));
126+
return CreateBufferImpl(env, refreshed_requirements, tensor_type);
127+
} else {
128+
LITERT_ASSIGN_OR_RETURN(
129+
const TensorBufferRequirements& requirements,
130+
GetOutputBufferRequirements(signature_index, tensor_name));
131+
return CreateBufferImpl(env, requirements, tensor_type);
124132
}
125133
}
126-
LITERT_ASSIGN_OR_RETURN(auto env, GetEnvironment());
127-
return CreateBufferImpl(env, buffer_requirements, tensor_type);
128134
}
129135

130136
Expected<std::vector<TensorBuffer>> CompiledModel::CreateInputOutputBuffers(

litert/cc/litert_compiled_model_test.cc

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,33 @@ TEST(CompiledModelTest,
141141
absl::string_view signature_key = model.DefaultSignatureKey();
142142

143143
const std::vector<int> resized_dims = {4, 2, 3};
144-
LITERT_ASSERT_OK(compiled_model.ResizeInputTensor(
145-
signature_key, "arg0", absl::MakeConstSpan(resized_dims)));
144+
LITERT_ASSERT_OK_AND_ASSIGN(const auto& input_names,
145+
model.GetSignatureInputNames(signature_key));
146+
absl::flat_hash_map<absl::string_view, TensorBuffer> input_map;
147+
for (const auto& input_name : input_names) {
148+
LITERT_ASSERT_OK(compiled_model.ResizeInputTensor(
149+
signature_key, input_name, absl::MakeConstSpan(resized_dims)));
150+
LITERT_ASSERT_OK_AND_ASSIGN(
151+
TensorBuffer input_buffer,
152+
compiled_model.CreateInputBuffer(signature_key, input_name));
153+
LITERT_ASSERT_OK_AND_ASSIGN(RankedTensorType buffer_type,
154+
input_buffer.TensorType());
155+
EXPECT_THAT(buffer_type.Layout().Dimensions(),
156+
ElementsAre(resized_dims[0], resized_dims[1], resized_dims[2]));
157+
input_map[input_name] = std::move(input_buffer);
158+
}
146159

147160
LITERT_ASSERT_OK_AND_ASSIGN(
148-
TensorBuffer input_buffer,
149-
compiled_model.CreateInputBuffer(signature_key, "arg0"));
161+
TensorBuffer output_buffer,
162+
compiled_model.CreateOutputBuffer(signature_key, "tfl.add"));
150163
LITERT_ASSERT_OK_AND_ASSIGN(RankedTensorType buffer_type,
151-
input_buffer.TensorType());
164+
output_buffer.TensorType());
152165
EXPECT_THAT(buffer_type.Layout().Dimensions(),
153166
ElementsAre(resized_dims[0], resized_dims[1], resized_dims[2]));
167+
absl::flat_hash_map<absl::string_view, TensorBuffer> output_map;
168+
output_map["tfl.add"] = std::move(output_buffer);
169+
170+
LITERT_ASSERT_OK(compiled_model.Run(input_map, output_map));
154171
}
155172

156173
TEST(CompiledModelTest, BasicSignatureIndex) {

litert/runtime/compiled_model.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,11 +1451,9 @@ litert::Expected<void> LiteRtCompiledModelT::ResizeInputTensor(
14511451
"Failed to resize input tensor");
14521452
}
14531453

1454-
// Clear cached buffer requirements for this tensor
1455-
LITERT_ASSIGN_OR_RETURN(const auto tensor_id,
1456-
GetTensorIdentifier(*interp_, input_tensor));
1457-
cpu_buffer_requirements_.erase(tensor_id);
1458-
1454+
// Clear cached buffer requirements for all tensors since output and
1455+
// intermediate tensors may change shape after an explicit resize.
1456+
cpu_buffer_requirements_.clear();
14591457
return {};
14601458
}
14611459

0 commit comments

Comments
 (0)