@@ -94,13 +94,11 @@ Expected<TensorBuffer> CompiledModel::CreateInputOutputBuffer(
9494 is_input ? model_.GetInputTensorType (signature_index, tensor_name)
9595 : model_.GetOutputTensorType (signature_index, tensor_name);
9696 LITERT_ASSIGN_OR_RETURN (RankedTensorType tensor_type, tensor_type_expected);
97- Expected<TensorBufferRequirements> buffer_requirements_expected =
98- is_input ? GetInputBufferRequirements (signature_index, tensor_name)
99- : GetOutputBufferRequirements (signature_index, tensor_name);
100-
101- LITERT_ASSIGN_OR_RETURN (const TensorBufferRequirements& buffer_requirements,
102- buffer_requirements_expected);
97+ LITERT_ASSIGN_OR_RETURN (auto env, GetEnvironment ());
10398 if (is_input) {
99+ LITERT_ASSIGN_OR_RETURN (
100+ TensorBufferRequirements buffer_requirements,
101+ GetInputBufferRequirements (signature_index, tensor_name));
104102 LITERT_ASSIGN_OR_RETURN (size_t tensor_index,
105103 FindInputIndex (signature_index, tensor_name));
106104 LiteRtLayout input_layout;
@@ -111,20 +109,28 @@ Expected<TensorBuffer> CompiledModel::CreateInputOutputBuffer(
111109 tensor_type = RankedTensorType (tensor_type.ElementType (),
112110 std::move (runtime_layout));
113111 }
112+ return CreateBufferImpl (env, buffer_requirements, tensor_type);
114113 } else {
115114 const auto & dims = tensor_type.Layout ().Dimensions ();
116115 if (absl::c_find (dims, -1 ) != dims.end ()) {
117116 LITERT_ASSIGN_OR_RETURN (size_t tensor_index,
118117 FindOutputIndex (signature_index, tensor_name));
119118 LITERT_ASSIGN_OR_RETURN (
120- std::vector<Layout> output_layouts ,
119+ std::vector<Layout> runtime_layouts ,
121120 GetOutputTensorLayouts (signature_index, /* update_allocation=*/ true ));
122121 tensor_type = RankedTensorType (tensor_type.ElementType (),
123- std::move (output_layouts[tensor_index]));
122+ std::move (runtime_layouts[tensor_index]));
123+ LITERT_ASSIGN_OR_RETURN (
124+ const TensorBufferRequirements& refreshed_requirements,
125+ GetOutputBufferRequirements (signature_index, tensor_name));
126+ return CreateBufferImpl (env, refreshed_requirements, tensor_type);
127+ } else {
128+ LITERT_ASSIGN_OR_RETURN (
129+ const TensorBufferRequirements& requirements,
130+ GetOutputBufferRequirements (signature_index, tensor_name));
131+ return CreateBufferImpl (env, requirements, tensor_type);
124132 }
125133 }
126- LITERT_ASSIGN_OR_RETURN (auto env, GetEnvironment ());
127- return CreateBufferImpl (env, buffer_requirements, tensor_type);
128134}
129135
130136Expected<std::vector<TensorBuffer>> CompiledModel::CreateInputOutputBuffers (
0 commit comments