@@ -26,36 +26,40 @@ namespace neutron {
   ((size + BUFFER_ALIGNMENT - 1) & (~(BUFFER_ALIGNMENT - 1)))
 
 /* Header schema:
- +----------------------------------+-----------------------------------+
- | Input TensorFormats length (1B)  | Output TensorFormats length (1B)  |
- +----------------------------------+-----------------------------------+
- | 1st input tensor format (1B)     | [nth* input tensor format (1B)]   |
- +----------------------------------+-----------------------------------+
- | 1st output tensor format (1B)    | [nth* output tensor format (1B)]  |
- +----------------------------------+-----------------------------------+
+ +----------------------------+-----------------------------+------------------------+
+ | Neutron inputs length (1B) | Neutron outputs length (1B) | Input args length (1B) |
+ +----------------------------+-----------+-----------------+------------------------+
+ | 1st input tensor format (1B)           | [nth* input tensor format (1B)]          |
+ +----------------------------------------+------------------------------------------+
+ | 1st output tensor format (1B)          | [nth* output tensor format (1B)]         |
+ +----------------------------------------+------------------------------------------+
+ | 1st input map (1B)                     | [nth* input map (1B)]                    |
+ +----------------------------------------+------------------------------------------+
+ | 1st output map (1B)                    | [nth* output map (1B)]                   |
+ +----------------------------------------+------------------------------------------+
 */
 #define ITEM_SIZE 1 // 1 Byte
 #define INPUT_TENSOR_FORMAT_LEN_POS 0
 #define OUTPUT_TENSOR_FORMAT_LEN_POS 1
-#define INPUT_TENSOR_FORMAT_ARRAY_ADDR(base) (base + 2 * ITEM_SIZE)
-#define OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(base) \
-  (base + 2 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS])
-#define PAYLOAD_ADDR(base) \
-  (base +                  \
-   ALIGN_SIZE(             \
-       2 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS] + \
-       base[OUTPUT_TENSOR_FORMAT_LEN_POS]))
+#define INPUT_ARGS_LEN_POS 2
+#define INPUT_TENSOR_FORMAT_ARRAY_ADDR(base) (base + 3 * ITEM_SIZE)
+#define OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(base) (base + 3 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS])
+#define INPUT_TENSOR_MAP_ARRAY_ADDR(base) (base + 3 * ITEM_SIZE + 1 * base[INPUT_TENSOR_FORMAT_LEN_POS] + 1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])
+#define OUTPUT_TENSOR_MAP_ARRAY_ADDR(base) (base + 3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + 1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])
+#define PAYLOAD_ADDR(base) (base + ALIGN_SIZE(3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + 2 * base[OUTPUT_TENSOR_FORMAT_LEN_POS]))
 
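// Reviewer sketch, not part of the patch: how the new header macros resolve for a
// hypothetical model with 2 Neutron inputs, 1 Neutron output and 3 input args,
// assuming BUFFER_ALIGNMENT is 16 for the ALIGN_SIZE step.
//
//   base[0] = 2   // Neutron inputs length
//   base[1] = 1   // Neutron outputs length
//   base[2] = 3   // Input args length
//   INPUT_TENSOR_FORMAT_ARRAY_ADDR(base)   -> base + 3
//   OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(base)  -> base + 3 + 2         = base + 5
//   INPUT_TENSOR_MAP_ARRAY_ADDR(base)      -> base + 3 + 2 + 1     = base + 6
//   OUTPUT_TENSOR_MAP_ARRAY_ADDR(base)     -> base + 3 + 2*2 + 1   = base + 8
//   PAYLOAD_ADDR(base)                     -> base + ALIGN_SIZE(3 + 2*2 + 2*1)
//                                           = base + ALIGN_SIZE(9) = base + 16
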
 // Aggregate neutron model handle and data structures into one.
 typedef struct {
-  int numInputs = 0;
-  int numOutputs = 0;
-  uint32_t scratchSize = 0;
-  NeutronModelConfig mcfg;
-  NeutronDataConfig dcfg;
-  NeutronModelHandle nmh = NULL;
-  const uint8_t* inputTranspositionFlags;
-  const uint8_t* outputTranspositionFlags;
+  int numInputs = 0;
+  int numOutputs = 0;
+  int numInputArgs = 0;
+  NeutronModelConfig mcfg;
+  NeutronDataConfig dcfg;
+  NeutronModelHandle nmh = NULL;
+  const uint8_t* inputTranspositionFlags;
+  const uint8_t* outputTranspositionFlags;
+  const uint8_t* inputMap;
+  const uint8_t* outputMap;
 } NeutronConfig;
 
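// Reviewer sketch, not part of the patch: how the new NeutronConfig fields are used
// by init() and execute() below. The concrete values are hypothetical.
//
//   numInputArgs               - number of leading entries in args[] that are inputs;
//                                outputs follow at args[numInputArgs + ...].
//   inputMap[i]                - index into args[] backing the i-th Neutron input,
//                                e.g. inputMap = {2, 0} binds Neutron input 0 to args[2].
//   outputMap[i]               - offset after numInputArgs for the i-th Neutron output.
//   inputTranspositionFlags[i] / outputTranspositionFlags[i]
//                              - non-zero when the tensor must be converted between
//                                channel-first (ExecuTorch) and channel-last (Neutron).
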
 // Applied on outputs.
@@ -210,6 +214,15 @@ void transposeOutput(
   }
 }
 
+bool multipleChannelsPresent(const ArrayRef<exec_aten::SizesType>& sizes) {
+  size_t length = sizes.size();
+  if (length < 3) {
+    return true;
+  }
+  size_t C = sizes[length - 3];
+  return C != 1;
+}
+
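// Reviewer sketch, not part of the patch: the intent of the check above. For a
// tensor shaped (..., C, H, W), the channel-first and channel-last layouts are
// byte-identical when C == 1, so the transposition can be skipped. Rank < 3
// deliberately returns true so that a 1D/2D tensor flagged for transposition still
// reaches the existing rank check and its error message.
//
//   sizes = {1, 3, 8, 8} -> C == 3 -> true  (transpose)
//   sizes = {1, 1, 8, 8} -> C == 1 -> false (layouts identical, skip)
//   sizes = {10, 20}     -> rank 2 -> true  (rank check reports the error)
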
 class NeutronBackend final : public PyTorchBackendInterface {
  public:
   NeutronBackend() {}
@@ -234,17 +247,17 @@ class NeutronBackend final : public PyTorchBackendInterface {
     // cfg->mcfg.microcode
     // cfg->mcfg.weights
     // cfg->mcfg.kernels
-    const uint8_t* transpositionFlags =
-        static_cast<const uint8_t*>(processed->data());
-    int numInputs = transpositionFlags[INPUT_TENSOR_FORMAT_LEN_POS];
-    int numOutputs = transpositionFlags[OUTPUT_TENSOR_FORMAT_LEN_POS];
-    cfg->inputTranspositionFlags =
-        INPUT_TENSOR_FORMAT_ARRAY_ADDR(transpositionFlags);
-    cfg->outputTranspositionFlags =
-        OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(transpositionFlags);
+    const uint8_t* payloadFlags = static_cast<const uint8_t*>(processed->data());
+    uint32_t numInputs = payloadFlags[INPUT_TENSOR_FORMAT_LEN_POS];
+    uint32_t numOutputs = payloadFlags[OUTPUT_TENSOR_FORMAT_LEN_POS];
+    cfg->numInputArgs = payloadFlags[INPUT_ARGS_LEN_POS];
+    cfg->inputTranspositionFlags = INPUT_TENSOR_FORMAT_ARRAY_ADDR(payloadFlags);
+    cfg->outputTranspositionFlags = OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(payloadFlags);
+    cfg->inputMap = INPUT_TENSOR_MAP_ARRAY_ADDR(payloadFlags);
+    cfg->outputMap = OUTPUT_TENSOR_MAP_ARRAY_ADDR(payloadFlags);
 
     const uint32_t* buffer = static_cast<const uint32_t*>(
-        static_cast<const void*> PAYLOAD_ADDR(transpositionFlags));
+        static_cast<const void*>PAYLOAD_ADDR(payloadFlags));
     uint32_t magicWord = buffer[0];
     // Check valid microcode.
     if (magicWord != 0x64434D6E) {
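// Reviewer sketch, not part of the patch: what init() would read from the hypothetical
// 16-byte header used in the earlier example (2 Neutron inputs, 1 output, 3 input args).
//
//   payloadFlags[INPUT_TENSOR_FORMAT_LEN_POS]  -> numInputs  = 2
//   payloadFlags[OUTPUT_TENSOR_FORMAT_LEN_POS] -> numOutputs = 1
//   payloadFlags[INPUT_ARGS_LEN_POS]           -> cfg->numInputArgs = 3
//   cfg->inputTranspositionFlags  = payloadFlags + 3   (2 bytes, one per Neutron input)
//   cfg->outputTranspositionFlags = payloadFlags + 5   (1 byte)
//   cfg->inputMap                 = payloadFlags + 6   (2 bytes)
//   cfg->outputMap                = payloadFlags + 8   (1 byte)
//   buffer = PAYLOAD_ADDR(payloadFlags) = payloadFlags + 16, and buffer[0] must hold
//   the microcode magic word 0x64434D6E for the payload to be accepted.
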
@@ -314,39 +327,38 @@ class NeutronBackend final : public PyTorchBackendInterface {
     cfg->dcfg.outputs[cfg->numOutputs] =
         static_cast<void*>(context.allocate(cfg->scratchSize, 16));
 
-    // Set inputs and outputs from args.
+    // Set inputs from args.
+    // Transpose inputs if needed.
     for (int i = 0; i < cfg->numInputs; i++) {
-      cfg->dcfg.inputs[i] = args[i]->toTensor().const_data_ptr();
-    }
-    for (int i = 0; i < cfg->numOutputs; i++) {
-      cfg->dcfg.outputs[i] =
-          args[cfg->numInputs + i]->toTensor().mutable_data_ptr();
-    }
-
-    // Transpose inputs.
-    for (int i = 0; i < cfg->numInputs; i++) {
-      if (cfg->inputTranspositionFlags[i]) {
-        if (args[i]->toTensor().sizes().size() < 3) {
+      auto arg = args[cfg->inputMap[i]]->toTensor();
+      if (cfg->inputTranspositionFlags[i] && multipleChannelsPresent(arg.sizes())) {
+        if (arg.sizes().size() < 3) {
           ET_LOG(Error, "Unable to transpose 1D and 2D input to channel last");
           return Error::InvalidProgram;
         }
         // Allocate buffer, the allocator is reset after each PTE instruction.
-        void* buffer = context.allocate(args[i]->toTensor().nbytes(), 16);
+        void* buffer = context.allocate(arg.nbytes());
         transposeInput(
-            args[i]->toTensor().const_data_ptr(),
+            arg.const_data_ptr(),
             buffer,
-            args[i]->toTensor().sizes(),
-            args[i]->toTensor().element_size());
+            arg.sizes(),
+            arg.element_size());
         cfg->dcfg.inputs[i] = buffer;
+      } else {
+        cfg->dcfg.inputs[i] = arg.const_data_ptr();
       }
     }
-    // Redirect outputs.
+
+    // Set outputs from args.
+    // Redirect outputs if needed before transposition.
     for (int i = 0; i < cfg->numOutputs; i++) {
-      if (cfg->outputTranspositionFlags[i]) {
+      auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
+      if (cfg->outputTranspositionFlags[i] && multipleChannelsPresent(arg.sizes())) {
         // Allocate buffer, the allocator is reset after each PTE instruction.
-        void* buffer =
-            context.allocate(args[cfg->numInputs + i]->toTensor().nbytes(), 16);
+        void* buffer = context.allocate(arg.nbytes());
         cfg->dcfg.outputs[i] = buffer;
+      } else {
+        cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
       }
     }
 
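// Reviewer sketch, not part of the patch: how the loops above bind the Neutron I/O to
// the ExecuTorch args for the hypothetical maps inputMap = {2, 0}, outputMap = {0} and
// numInputArgs = 3, assuming no transposition flag is set:
//
//   dcfg.inputs[0]  = args[inputMap[0]]  = args[2]  (const_data_ptr)
//   dcfg.inputs[1]  = args[inputMap[1]]  = args[0]  (const_data_ptr)
//   dcfg.outputs[0] = args[numInputArgs + outputMap[0]] = args[3]  (mutable_data_ptr)
//
// When a flag is set and the tensor has more than one channel, the corresponding entry
// points at a temporary buffer instead; inputs are transposed into it here, and outputs
// are transposed out of it after the Neutron run (see the next hunk).
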
@@ -368,17 +380,16 @@ class NeutronBackend final : public PyTorchBackendInterface {
 
     // Transpose outputs.
     for (int i = 0; i < cfg->numOutputs; i++) {
-      if (cfg->outputTranspositionFlags[i]) {
-        if (args[cfg->numInputs + i]->toTensor().sizes().size() < 3) {
-          ET_LOG(
-              Error, "Unable to transpose 1D and 2D output to channel first");
+      auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
+      if (cfg->outputTranspositionFlags[i] && multipleChannelsPresent(arg.sizes())) {
+        if (arg.sizes().size() < 3) {
+          ET_LOG(Error, "Unable to transpose 1D and 2D output to channel first");
           return Error::InvalidProgram;
         }
-        transposeOutput(
-            cfg->dcfg.outputs[i],
-            args[cfg->numInputs + i]->toTensor().mutable_data_ptr(),
-            args[cfg->numInputs + i]->toTensor().sizes(),
-            args[cfg->numInputs + i]->toTensor().element_size());
+        transposeOutput(cfg->dcfg.outputs[i],
+                        arg.mutable_data_ptr(),
+                        arg.sizes(),
+                        arg.element_size());
       }
     }
 
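// Reviewer sketch, not part of the patch: a minimal channel-last -> channel-first copy
// for the 4D case, illustrating the kind of work transposeOutput() is expected to do
// after the Neutron kernel has produced channel-last data. The function name and the
// flat (N, C, H, W) parameters are hypothetical; the real routine works on ArrayRef
// sizes and an element size, as seen in the calls above.
#include <cstddef>
#include <cstdint>
#include <cstring>

static void channelLastToChannelFirst(
    const uint8_t* src, uint8_t* dst,
    size_t N, size_t C, size_t H, size_t W, size_t elemSize) {
  for (size_t n = 0; n < N; n++) {
    for (size_t c = 0; c < C; c++) {
      for (size_t h = 0; h < H; h++) {
        for (size_t w = 0; w < W; w++) {
          // Source is laid out as NHWC (channel last), destination as NCHW.
          size_t nhwc = (((n * H + h) * W + w) * C + c) * elemSize;
          size_t nchw = (((n * C + c) * H + h) * W + w) * elemSize;
          std::memcpy(dst + nchw, src + nhwc, elemSize);
        }
      }
    }
  }
}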