@@ -261,12 +261,24 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
261261 event_tracer,
262262 " +EthosUBackend::execute()handles.input.permute_CHW_to_HWC()" );
263263 // permuted byte copy CHW to HWC
264+ int c, h, w;
265+ if (tensor_in.dim () == 4 ) {
266+ c = tensor_in.size (1 );
267+ h = tensor_in.size (2 );
268+ w = tensor_in.size (3 );
269+ } else if (tensor_in.dim () == 5 ) {
270+ c = tensor_in.size (2 );
271+ h = tensor_in.size (3 );
272+ w = tensor_in.size (4 );
273+ } else {
274+ ET_LOG (
275+ Error,
276+ " Unsupported input tensor dimension %d, expected 4 or 5" ,
277+ tensor_in.dim ());
278+ return Error::InvalidProgram;
279+ }
264280 permute_CHW_to_HWC (
265- tensor_in.mutable_data_ptr <char >(),
266- scratch_addr,
267- tensor_in.size (1 ),
268- tensor_in.size (2 ),
269- tensor_in.size (3 ));
281+ tensor_in.mutable_data_ptr <char >(), scratch_addr, c, h, w);
270282 } else if (both_char or both_int or both_short) {
271283 EXECUTORCH_PROF_SCOPE (
272284 event_tracer, " +EthosUBackend::execute()handles.input.memcpy()" );
@@ -364,12 +376,24 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
364376 " +EthosUBackend::execute()handles.output.permute_HWC_to_CHW()" );
365377
366378 char * output_address = (char *)output_addr;
379+ int c, h, w;
380+ if (tensor_out.dim () == 4 ) {
381+ c = tensor_out.size (1 );
382+ h = tensor_out.size (2 );
383+ w = tensor_out.size (3 );
384+ } else if (tensor_out.dim () == 5 ) {
385+ c = tensor_out.size (2 );
386+ h = tensor_out.size (3 );
387+ w = tensor_out.size (4 );
388+ } else {
389+ ET_LOG (
390+ Error,
391+ " Unsupported output tensor dimension %d, expected 4 or 5" ,
392+ tensor_out.dim ());
393+ return Error::InvalidProgram;
394+ }
367395 permute_HWC_to_CHW (
368- output_address,
369- tensor_out.mutable_data_ptr <char >(),
370- tensor_out.size (1 ),
371- tensor_out.size (2 ),
372- tensor_out.size (3 ));
396+ output_address, tensor_out.mutable_data_ptr <char >(), c, h, w);
373397 } else {
374398 EXECUTORCH_PROF_SCOPE (
375399 event_tracer, " +EthosUBackend::execute()handles.output.move()" );
@@ -430,6 +454,14 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
430454 if (permuted_shape) {
431455 ET_LOG (Debug, " Tensor input/output %d will be permuted" , index);
432456 }
457+ } else if (tensor.dim () == 5 ) {
458+ // Same as above, but for 5D tensors.
459+ permuted_shape = tensor.size (0 ) == io->shape [0 ] &&
460+ tensor.size (1 ) == io->shape [1 ] && tensor.size (2 ) == io->shape [4 ] &&
461+ tensor.size (3 ) == io->shape [2 ] && tensor.size (4 ) == io->shape [3 ];
462+ if (permuted_shape) {
463+ ET_LOG (Debug, " Tensor input/output %d will be permuted" , index);
464+ }
433465 }
434466 *is_permuted = permuted_shape;
435467 return Error::Ok;
0 commit comments