@@ -279,12 +279,11 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
279279 event_tracer,
280280 " +EthosUBackend::execute()handles.input.permute_CHW_to_HWC()" );
281281 // permuted byte copy CHW to HWC
282+ int c, h, w;
283+ ET_CHECK_OK_OR_RETURN_ERROR (get_chw (tensor_in, &c, &h, &w));
284+
282285 permute_CHW_to_HWC (
283- tensor_in.mutable_data_ptr <char >(),
284- scratch_addr,
285- tensor_in.size (1 ),
286- tensor_in.size (2 ),
287- tensor_in.size (3 ));
286+ tensor_in.mutable_data_ptr <char >(), scratch_addr, c, h, w);
288287 } else if (both_char || both_int || both_short || both_bool) {
289288 EXECUTORCH_PROF_SCOPE (
290289 event_tracer, " +EthosUBackend::execute()handles.input.memcpy()" );
@@ -381,13 +380,11 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
381380 " +EthosUBackend::execute()handles.output.permute_HWC_to_CHW()" );
382381
383382 const char * output_address = static_cast <const char *>(output_addr);
383+ int c, h, w;
384+ ET_CHECK_OK_OR_RETURN_ERROR (get_chw (tensor_out, &c, &h, &w));
384385
385386 permute_HWC_to_CHW (
386- output_address,
387- tensor_out.mutable_data_ptr <char >(),
388- tensor_out.size (1 ),
389- tensor_out.size (2 ),
390- tensor_out.size (3 ));
387+ output_address, tensor_out.mutable_data_ptr <char >(), c, h, w);
391388 } else {
392389 EXECUTORCH_PROF_SCOPE (
393390 event_tracer, " +EthosUBackend::execute()handles.output.memcpy()" );
@@ -421,8 +418,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
421418 *tensor_count = *tensor_count * tensor.size (i);
422419 }
423420
424- // The VelaIO type has a shape of fixed size 4
425- for (int i = 0 ; i < 4 ; i++) {
421+ // The VelaIO shape array has a fixed size of shapeDim (currently 6) entries
422+ for (int i = 0 ; i < shapeDim ; i++) {
426423 *io_count = *io_count * io->shape [i];
427424 }
428425 }
@@ -438,17 +435,46 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
438435 // special case for NHWC workaround in AOT; as the compilation has
439436 // permuted to channel last in an undetectable way, we assume here
440437 // that the application has similarly permuted any input/output tensors.
441- permuted_shape = tensor.size (0 ) == io->shape [0 ] &&
442- tensor.size (1 ) == io->shape [3 ] && tensor.size (2 ) == io->shape [1 ] &&
443- tensor.size (3 ) == io->shape [2 ];
438+ permuted_shape =
439+ tensor.size (0 ) == io->shape [0 ] * io->shape [1 ] * io->shape [2 ] &&
440+ tensor.size (1 ) == io->shape [5 ] && tensor.size (2 ) == io->shape [3 ] &&
441+ tensor.size (3 ) == io->shape [4 ];
444442 if (permuted_shape) {
445- ET_LOG (Debug, " Tensor input/output %d will be permuted" , index);
443+ ET_LOG (Debug, " 4D tensor input/output %d will be permuted" , index);
444+ }
445+ } else if (tensor.dim () == 5 ) {
446+ // tensor has format NNCHW, but the VelaIO is in NNNHWC
447+ permuted_shape = io->shape [0 ] == 1 && tensor.size (0 ) == io->shape [1 ] &&
448+ tensor.size (1 ) == io->shape [2 ] && tensor.size (2 ) == io->shape [5 ] &&
449+ tensor.size (3 ) == io->shape [3 ] && tensor.size (4 ) == io->shape [4 ];
450+ if (permuted_shape) {
451+ ET_LOG (Debug, " 5D tensor input/output %d will be permuted" , index);
446452 }
447453 }
448454 *is_permuted = permuted_shape;
449455 return Error::Ok;
450456 }
451457
458+ Error get_chw (const executorch::aten::Tensor tensor, int * c, int * h, int * w)
459+ const {
460+ if (tensor.dim () == 4 ) {
461+ *c = tensor.size (1 );
462+ *h = tensor.size (2 );
463+ *w = tensor.size (3 );
464+ } else if (tensor.dim () == 5 ) {
465+ *c = tensor.size (2 );
466+ *h = tensor.size (3 );
467+ *w = tensor.size (4 );
468+ } else {
469+ ET_LOG (
470+ Error,
471+ " Unsupported output tensor dimension %d, expected 4 or 5" ,
472+ tensor.dim ());
473+ return Error::InvalidProgram;
474+ }
475+ return Error::Ok;
476+ }
477+
452478 void permute_CHW_to_HWC (const char * input, char * output, int C, int H, int W)
453479 const {
454480 for (int i = 0 ; i != H * W; ++i) {
0 commit comments