@@ -277,12 +277,11 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
             event_tracer,
             "+EthosUBackend::execute()handles.input.permute_CHW_to_HWC()");
         // permuted byte copy CHW to HWC
+        int c, h, w;
+        ET_CHECK_OK_OR_RETURN_ERROR(get_chw(tensor_in, &c, &h, &w));
+
         permute_CHW_to_HWC(
-            tensor_in.mutable_data_ptr<char>(),
-            scratch_addr,
-            tensor_in.size(1),
-            tensor_in.size(2),
-            tensor_in.size(3));
+            tensor_in.mutable_data_ptr<char>(), scratch_addr, c, h, w);
       } else if (both_char || both_int || both_short || both_bool) {
         EXECUTORCH_PROF_SCOPE(
             event_tracer, "+EthosUBackend::execute()handles.input.memcpy()");
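
For orientation: the permute invoked here converts the input byte buffer from channels-first (CHW) to channels-last (HWC) before it reaches the NPU scratch area. Below is a minimal standalone sketch of such a byte permute, assuming a dense C*H*W buffer; the backend's own permute_CHW_to_HWC (whose opening lines appear at the end of this diff) need not match it exactly.

    // Sketch only: copy a C*H*W byte buffer from CHW (channels-first)
    // to HWC (channels-last) layout. Assumes output holds C*H*W bytes.
    static void chw_to_hwc_sketch(
        const char* input, char* output, int C, int H, int W) {
      for (int i = 0; i != H * W; ++i) { // each spatial position
        for (int j = 0; j < C; ++j) { // gather its C channel bytes
          output[i * C + j] = input[j * H * W + i];
        }
      }
    }
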
@@ -379,13 +378,11 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
             "+EthosUBackend::execute()handles.output.permute_HWC_to_CHW()");

         const char* output_address = static_cast<const char*>(output_addr);
+        int c, h, w;
+        ET_CHECK_OK_OR_RETURN_ERROR(get_chw(tensor_out, &c, &h, &w));

         permute_HWC_to_CHW(
-            output_address,
-            tensor_out.mutable_data_ptr<char>(),
-            tensor_out.size(1),
-            tensor_out.size(2),
-            tensor_out.size(3));
+            output_address, tensor_out.mutable_data_ptr<char>(), c, h, w);
       } else {
         EXECUTORCH_PROF_SCOPE(
             event_tracer, "+EthosUBackend::execute()handles.output.memcpy()");
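
The output path is the mirror image. A sketch of the inverse HWC-to-CHW scatter, under the same layout assumptions as the sketch above:

    // Sketch only: the inverse of the CHW->HWC copy, scattering a
    // channels-last buffer back into channels-first order.
    static void hwc_to_chw_sketch(
        const char* input, char* output, int C, int H, int W) {
      for (int i = 0; i != H * W; ++i) {
        for (int j = 0; j < C; ++j) {
          output[j * H * W + i] = input[i * C + j];
        }
      }
    }
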
@@ -419,8 +416,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       *tensor_count = *tensor_count * tensor.size(i);
     }

-    // The VelaIO type has a shape of fixed size 4
-    for (int i = 0; i < 4; i++) {
+    // The VelaIO type has a shape of fixed size 6
+    for (int i = 0; i < shapeDim; i++) {
       *io_count = *io_count * io->shape[i];
     }
   }
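
Both counts are plain element products, so a tensor and its VelaIO descriptor agree exactly when the products are equal. With the shape array widened from four to six entries, this presumably relies on unused dimensions being padded with 1, which leaves the product unchanged. A sketch of that invariant; same_element_count and kShapeDim are illustrative names, with kShapeDim standing in for the backend's shapeDim constant (assumed to be 6):

    // Illustrative check, not the backend's code: element counts match
    // when the dimension products are equal. Assumes unused VelaIO
    // dimensions are padded with 1, a factor that leaves the product
    // unchanged.
    constexpr int kShapeDim = 6; // stand-in for shapeDim
    bool same_element_count(
        const int* tensor_sizes, int tensor_dim, const int* vela_shape) {
      long tensor_count = 1;
      long io_count = 1;
      for (int i = 0; i < tensor_dim; i++)
        tensor_count *= tensor_sizes[i];
      for (int i = 0; i < kShapeDim; i++)
        io_count *= vela_shape[i];
      return tensor_count == io_count;
    }
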
@@ -436,17 +433,46 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       // special case for NHWC workaround in AOT; as the compilation has
       // permuted to channel last in an undetectable way, we assume here
       // that the application has similarly permuted any input/output tensors.
-      permuted_shape = tensor.size(0) == io->shape[0] &&
-          tensor.size(1) == io->shape[3] && tensor.size(2) == io->shape[1] &&
-          tensor.size(3) == io->shape[2];
+      permuted_shape =
+          tensor.size(0) == io->shape[0] * io->shape[1] * io->shape[2] &&
+          tensor.size(1) == io->shape[5] && tensor.size(2) == io->shape[3] &&
+          tensor.size(3) == io->shape[4];
       if (permuted_shape) {
-        ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
+        ET_LOG(Debug, "4D tensor input/output %d will be permuted", index);
+      }
+    } else if (tensor.dim() == 5) {
+      // tensor has format NNCHW, but the VelaIO is in NNNHWC
+      permuted_shape = io->shape[0] == 1 && tensor.size(0) == io->shape[1] &&
+          tensor.size(1) == io->shape[2] && tensor.size(2) == io->shape[5] &&
+          tensor.size(3) == io->shape[3] && tensor.size(4) == io->shape[4];
+      if (permuted_shape) {
+        ET_LOG(Debug, "5D tensor input/output %d will be permuted", index);
       }
     }
     *is_permuted = permuted_shape;
     return Error::Ok;
   }

+  Error get_chw(const executorch::aten::Tensor tensor, int* c, int* h, int* w)
+      const {
+    if (tensor.dim() == 4) {
+      *c = tensor.size(1);
+      *h = tensor.size(2);
+      *w = tensor.size(3);
+    } else if (tensor.dim() == 5) {
+      *c = tensor.size(2);
+      *h = tensor.size(3);
+      *w = tensor.size(4);
+    } else {
+      ET_LOG(
+          Error,
+          "Unsupported tensor dimension %d, expected 4 or 5",
+          tensor.dim());
+      return Error::InvalidProgram;
+    }
+    return Error::Ok;
+  }
+
   void permute_CHW_to_HWC(const char* input, char* output, int C, int H, int W)
       const {
     for (int i = 0; i != H * W; ++i) {
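
To make the new matching rules and get_chw concrete, a self-contained example with invented sizes (none of these shapes come from the diff itself):

    #include <cassert>

    int main() {
      // 4D: NCHW tensor (2, 3, 8, 8) vs. a six-entry VelaIO shape in
      // NHWC, with the batch folded over the first three entries.
      int t4[4] = {2, 3, 8, 8};
      int io4[6] = {1, 1, 2, 8, 8, 3};
      assert(t4[0] == io4[0] * io4[1] * io4[2]); // folded batch matches
      assert(t4[1] == io4[5] && t4[2] == io4[3] && t4[3] == io4[4]); // C, H, W

      // 5D: NNCHW tensor (2, 4, 3, 8, 8) vs. an NNNHWC VelaIO shape.
      int t5[5] = {2, 4, 3, 8, 8};
      int io5[6] = {1, 2, 4, 8, 8, 3};
      assert(io5[0] == 1 && t5[0] == io5[1] && t5[1] == io5[2]); // batch dims
      assert(t5[2] == io5[5] && t5[3] == io5[3] && t5[4] == io5[4]); // C, H, W

      // For this 5D tensor, get_chw would report c = 3, h = 8, w = 8
      // (dims 2, 3 and 4 of the NNCHW tensor).
      return 0;
    }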