@@ -138,6 +138,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
138138 // TODO(MLETORCH-123): Optimise into direct write from Vela into the SRAM
139139 // or DRAM output for compatible data layouts.
140140 for (int i = 0 ; i < handles.inputs ->count ; i++) {
141+ auto tensor_count = 1 , io_count = 1 ;
141142 auto tensor_in = args[i]->toTensor ();
142143 char * scratch_addr = handles.scratch_data + handles.inputs ->io [i].offset ;
143144
@@ -202,6 +203,19 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
202203 ET_LOG (Error, " No matching input copy routine" );
203204 return Error::InvalidProgram;
204205 }
206+ if (!permuted_input_shape) {
207+ calculate_dimensions (
208+ tensor_in, &handles.inputs ->io [i], &tensor_count, &io_count);
209+ if (tensor_count != io_count) {
210+ ET_LOG (Error, " Input tensor sizes do not match" );
211+ ET_LOG (
212+ Error,
213+ " Program expects %d elements but got %d" ,
214+ io_count,
215+ tensor_count);
216+ return Error::InvalidProgram;
217+ }
218+ }
205219 }
206220
207221 // Allocate driver handle and synchronously invoke driver
@@ -236,14 +250,24 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
236250 result);
237251 return Error::InvalidProgram;
238252 }
239-
253+ int tensor_dim = 0 , io_dim = 0 ;
240254 // Write outputs from scratch into EValue pointers
241255 for (int i = 0 ; i < handles.outputs ->count ; i++) {
256+ int tensor_count = 1 , io_count = 1 ;
242257 const char * output_addr =
243258 handles.scratch_data + handles.outputs ->io [i].offset ;
244259 // Process input EValue into scratch
245260 // Outputs are in the index immediately after inputs
246261 auto tensor_out = args[handles.inputs ->count + i]->toTensor ();
262+
263+ calculate_dimensions (
264+ tensor_out, &handles.outputs ->io [i], &tensor_count, &io_count);
265+
266+ // At times the topological order of the outputs may change,
267+ // so instead ensure that the summed element counts match.
268+ tensor_dim = tensor_dim + tensor_count;
269+ io_dim = io_dim + io_count;
270+
247271 bool permuted_output_shape;
248272 ET_CHECK_OK_OR_RETURN_ERROR (check_requires_permute (
249273 i,
@@ -272,6 +296,12 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
272296 }
273297 }
274298 }
299+ if (tensor_dim != io_dim) {
300+ ET_LOG (Error, " Total output tensor sizes do not match" );
301+ ET_LOG (
302+ Error, " Program expects size of %d but got %d" , tensor_dim, io_dim);
303+ return Error::InvalidProgram;
304+ }
275305 return Error::Ok;
276306 }
277307
@@ -280,13 +310,29 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
280310 }
281311
282312 private:
313+ void calculate_dimensions (
314+ const executorch::aten::Tensor tensor,
315+ VelaIO* io,
316+ int * tensor_count,
317+ int * io_count) const {
318+ for (int i = 0 ; i < tensor.dim (); i++) {
319+ *tensor_count = *tensor_count * tensor.size (i);
320+ }
321+
322+ // The VelaIO type has a shape of fixed size 4
323+ for (int i = 0 ; i < 4 ; i++) {
324+ *io_count = *io_count * io->shape [i];
325+ }
326+ }
327+
283328 Error check_requires_permute (
284329 int index,
285330 const executorch::aten::Tensor tensor,
286331 VelaIO* io,
287332 bool permuted_io_flag,
288333 bool * is_permuted) const {
289334 bool permuted_shape = false ;
335+
290336 if (tensor.dim () == 4 ) {
291337 // special case for NHWC workaround in AOT; as the compilation has
292338 // permuted to channel last in an undetectable way, we assume here
@@ -304,30 +350,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
304350 return Error::InvalidProgram;
305351 }
306352 }
307- if (!permuted_shape) {
308- // Check the number of elements in each tensor match
309- int tensor_count = 1 ;
310- int io_count = 1 ;
311-
312- for (int i = 0 ; i < tensor.dim (); i++) {
313- tensor_count = tensor_count * tensor.size (i);
314- }
315-
316- // The VelaIO type has a shape of fixed size 4
317- for (int i = 0 ; i < 4 ; i++) {
318- io_count = io_count * io->shape [i];
319- }
320-
321- if (tensor_count != io_count) {
322- ET_LOG (Error, " Input tensor sizes do not match" );
323- ET_LOG (
324- Error,
325- " Program expects %d elements but got %d" ,
326- io_count,
327- tensor_count);
328- return Error::InvalidProgram;
329- }
330- }
331353 *is_permuted = permuted_shape;
332354 return Error::Ok;
333355 }
0 commit comments