@@ -25,37 +25,52 @@ namespace neutron {
 #define ALIGN_SIZE(size) \
   ((size + BUFFER_ALIGNMENT - 1) & (~(BUFFER_ALIGNMENT - 1)))
 
+// clang-format off
 /* Header schema:
- +----------------------------------+-----------------------------------+
- | Input TensorFormats length (1B)  | Output TensorFormats length (1B)  |
- +----------------------------------+-----------------------------------+
- | 1st input tensor format (1B)     | [nth* input tensor format (1B)]   |
- +----------------------------------+-----------------------------------+
- | 1st output tensor format (1B)    | [nth* output tensor format (1B)]  |
- +----------------------------------+-----------------------------------+
+ +----------------------------+-----------------------------+------------------------+
+ | Neutron inputs length (1B) | Neutron outputs length (1B) | Input args length (1B) |
+ +----------------------------+-----------+-----------------+------------------------+
+ | 1st input tensor format (1B)           | [nth* input tensor format (1B)]          |
+ +----------------------------------------+------------------------------------------+
+ | 1st output tensor format (1B)          | [nth* output tensor format (1B)]         |
+ +----------------------------------------+------------------------------------------+
+ | 1st input map (1B)                     | [nth* input map (1B)]                    |
+ +----------------------------------------+------------------------------------------+
+ | 1st output map (1B)                    | [nth* output map (1B)]                   |
+ +----------------------------------------+------------------------------------------+
 */
+// clang-format on
 #define ITEM_SIZE 1 // 1 Byte
 #define INPUT_TENSOR_FORMAT_LEN_POS 0
 #define OUTPUT_TENSOR_FORMAT_LEN_POS 1
-#define INPUT_TENSOR_FORMAT_ARRAY_ADDR(base) (base + 2 * ITEM_SIZE)
+#define INPUT_ARGS_LEN_POS 2
+#define INPUT_TENSOR_FORMAT_ARRAY_ADDR(base) (base + 3 * ITEM_SIZE)
 #define OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(base) \
-  (base + 2 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS])
-#define PAYLOAD_ADDR(base) \
-  (base + \
-   ALIGN_SIZE( \
-       2 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS] + \
-       base[OUTPUT_TENSOR_FORMAT_LEN_POS]))
+  (base + 3 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS])
+#define INPUT_TENSOR_MAP_ARRAY_ADDR(base) \
+  (base + 3 * ITEM_SIZE + 1 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
+   1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])
+#define OUTPUT_TENSOR_MAP_ARRAY_ADDR(base) \
+  (base + 3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
+   1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])
+#define PAYLOAD_ADDR(base) \
+  (base + \
+   ALIGN_SIZE( \
+       3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
+       2 * base[OUTPUT_TENSOR_FORMAT_LEN_POS]))
 
 // Aggregate neutron model handle and data structures into one.
 typedef struct {
   int numInputs = 0;
   int numOutputs = 0;
-  uint32_t scratchSize = 0;
+  int numInputArgs = 0;
   NeutronModelConfig mcfg;
   NeutronDataConfig dcfg;
   NeutronModelHandle nmh = NULL;
   const uint8_t* inputTranspositionFlags;
   const uint8_t* outputTranspositionFlags;
+  const uint8_t* inputMap;
+  const uint8_t* outputMap;
 } NeutronConfig;
 
 // Applied on outputs.
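
To make the header offset arithmetic concrete: the *_ARRAY_ADDR and PAYLOAD_ADDR macros above walk the schema in order (three length bytes, the two format arrays, the two map arrays, then an aligned payload). The sketch below recomputes those offsets for a hypothetical header with 2 Neutron inputs, 1 output and 2 input args; the BUFFER_ALIGNMENT value of 16 and the header bytes are assumptions made purely for illustration.

#include <cstdint>
#include <cstdio>

constexpr uint32_t BUFFER_ALIGNMENT = 16; // assumed alignment for this sketch
#define ALIGN_SIZE(size) \
  ((size + BUFFER_ALIGNMENT - 1) & (~(BUFFER_ALIGNMENT - 1)))

int main() {
  // Hypothetical header: 2 Neutron inputs, 1 Neutron output, 2 input args.
  const uint8_t base[] = {2, 1, 2, /*in formats*/ 0, 1, /*out formats*/ 1,
                          /*input map*/ 0, 1, /*output map*/ 0};
  uint32_t numIn = base[0];  // INPUT_TENSOR_FORMAT_LEN_POS
  uint32_t numOut = base[1]; // OUTPUT_TENSOR_FORMAT_LEN_POS
  printf("input formats  at offset %u\n", 3u);                      // 3
  printf("output formats at offset %u\n", 3u + numIn);              // 5
  printf("input map      at offset %u\n", 3u + numIn + numOut);     // 6
  printf("output map     at offset %u\n", 3u + 2 * numIn + numOut); // 8
  printf("payload        at offset %u\n",
         ALIGN_SIZE(3u + 2 * numIn + 2 * numOut)); // 16 after alignment
  return 0;
}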
@@ -210,6 +225,15 @@ void transposeOutput(
   }
 }
 
+bool multipleChannelsPresent(const ArrayRef<exec_aten::SizesType>& sizes) {
+  size_t length = sizes.size();
+  if (length < 3) {
+    return true;
+  }
+  size_t C = sizes[length - 3];
+  return C != 1;
+}
+
 class NeutronBackend final : public PyTorchBackendInterface {
  public:
   NeutronBackend() {}
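
The new multipleChannelsPresent helper lets execute() skip transposition when the channel dimension (third from last) is 1: with a single channel, channel-first and channel-last orderings describe the same linear memory layout, so no copy is needed. For tensors with fewer than three dimensions the helper returns true, so the existing rank checks still reject them. A small standalone check of the single-channel claim, using an arbitrary 1x1x2x3 shape:

#include <cassert>
#include <cstddef>

int main() {
  const size_t N = 1, C = 1, H = 2, W = 3;
  for (size_t n = 0; n < N; ++n)
    for (size_t h = 0; h < H; ++h)
      for (size_t w = 0; w < W; ++w) {
        size_t c = 0; // only one channel
        size_t nchw = ((n * C + c) * H + h) * W + w;
        size_t nhwc = ((n * H + h) * W + w) * C + c;
        assert(nchw == nhwc); // identical flat offsets when C == 1
      }
  return 0;
}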
@@ -234,17 +258,19 @@ class NeutronBackend final : public PyTorchBackendInterface {
     // cfg->mcfg.microcode
     // cfg->mcfg.weights
     // cfg->mcfg.kernels
-    const uint8_t* transpositionFlags =
+    const uint8_t* payloadFlags =
         static_cast<const uint8_t*>(processed->data());
-    int numInputs = transpositionFlags[INPUT_TENSOR_FORMAT_LEN_POS];
-    int numOutputs = transpositionFlags[OUTPUT_TENSOR_FORMAT_LEN_POS];
-    cfg->inputTranspositionFlags =
-        INPUT_TENSOR_FORMAT_ARRAY_ADDR(transpositionFlags);
+    uint32_t numInputs = payloadFlags[INPUT_TENSOR_FORMAT_LEN_POS];
+    uint32_t numOutputs = payloadFlags[OUTPUT_TENSOR_FORMAT_LEN_POS];
+    cfg->numInputArgs = payloadFlags[INPUT_ARGS_LEN_POS];
+    cfg->inputTranspositionFlags = INPUT_TENSOR_FORMAT_ARRAY_ADDR(payloadFlags);
     cfg->outputTranspositionFlags =
-        OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(transpositionFlags);
+        OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(payloadFlags);
+    cfg->inputMap = INPUT_TENSOR_MAP_ARRAY_ADDR(payloadFlags);
+    cfg->outputMap = OUTPUT_TENSOR_MAP_ARRAY_ADDR(payloadFlags);
 
     const uint32_t* buffer = static_cast<const uint32_t*>(
-        static_cast<const void*> PAYLOAD_ADDR(transpositionFlags));
+        static_cast<const void*> PAYLOAD_ADDR(payloadFlags));
     uint32_t magicWord = buffer[0];
     // Check valid microcode.
     if (magicWord != 0x64434D6E) {
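
The magic-word check is easier to read once the constant is decoded: on a little-endian host the four bytes of 0x64434D6E are 'n', 'M', 'C', 'd' in memory order. A standalone sanity check of that decoding; the host endianness is the assumption here, the constant itself comes from the code above.

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  const uint32_t magicWord = 0x64434D6E;
  char text[5] = {0};
  std::memcpy(text, &magicWord, sizeof(magicWord)); // host byte order
  assert(std::strcmp(text, "nMCd") == 0); // holds on little-endian hosts
  return 0;
}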
@@ -314,39 +340,37 @@ class NeutronBackend final : public PyTorchBackendInterface {
     cfg->dcfg.outputs[cfg->numOutputs] =
         static_cast<void*>(context.allocate(cfg->scratchSize, 16));
 
-    // Set inputs and outputs from args.
+    // Set inputs from args.
+    // Transpose inputs if needed.
     for (int i = 0; i < cfg->numInputs; i++) {
-      cfg->dcfg.inputs[i] = args[i]->toTensor().const_data_ptr();
-    }
-    for (int i = 0; i < cfg->numOutputs; i++) {
-      cfg->dcfg.outputs[i] =
-          args[cfg->numInputs + i]->toTensor().mutable_data_ptr();
-    }
-
-    // Transpose inputs.
-    for (int i = 0; i < cfg->numInputs; i++) {
-      if (cfg->inputTranspositionFlags[i]) {
-        if (args[i]->toTensor().sizes().size() < 3) {
+      auto arg = args[cfg->inputMap[i]]->toTensor();
+      if (cfg->inputTranspositionFlags[i] &&
+          multipleChannelsPresent(arg.sizes())) {
+        if (arg.sizes().size() < 3) {
           ET_LOG(Error, "Unable to transpose 1D and 2D input to channel last");
           return Error::InvalidProgram;
         }
         // Allocate buffer, the allocator is reset after each PTE instruction.
-        void* buffer = context.allocate(args[i]->toTensor().nbytes(), 16);
+        void* buffer = context.allocate(arg.nbytes());
         transposeInput(
-            args[i]->toTensor().const_data_ptr(),
-            buffer,
-            args[i]->toTensor().sizes(),
-            args[i]->toTensor().element_size());
+            arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size());
         cfg->dcfg.inputs[i] = buffer;
+      } else {
+        cfg->dcfg.inputs[i] = arg.const_data_ptr();
       }
     }
-    // Redirect outputs.
+
+    // Set outputs from args.
+    // Redirect outputs if needed before transposition.
     for (int i = 0; i < cfg->numOutputs; i++) {
-      if (cfg->outputTranspositionFlags[i]) {
+      auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
+      if (cfg->outputTranspositionFlags[i] &&
+          multipleChannelsPresent(arg.sizes())) {
         // Allocate buffer, the allocator is reset after each PTE instruction.
-        void* buffer =
-            context.allocate(args[cfg->numInputs + i]->toTensor().nbytes(), 16);
+        void* buffer = context.allocate(arg.nbytes());
         cfg->dcfg.outputs[i] = buffer;
+      } else {
+        cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
       }
     }
 
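
The map arrays decouple Neutron's tensor order from the order of the ExecuTorch args: all numInputArgs input args come first, outputs follow, and inputMap[i] / outputMap[i] give the position of the i-th Neutron input or output within those two groups. A minimal sketch of that indexing; the arg names and map contents are made up for illustration only.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Hypothetical method signature: three input args, one output arg.
  std::vector<std::string> args = {"in0", "in1", "in2", "out0"};
  const int numInputArgs = 3;
  const uint8_t inputMap[] = {2, 0}; // Neutron input i reads args[inputMap[i]]
  const uint8_t outputMap[] = {0};   // output i writes args[numInputArgs + outputMap[i]]

  for (int i = 0; i < 2; ++i)
    printf("neutron input %d <- %s\n", i, args[inputMap[i]].c_str());
  for (int i = 0; i < 1; ++i)
    printf("neutron output %d -> %s\n", i, args[numInputArgs + outputMap[i]].c_str());
  return 0;
}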
@@ -368,17 +392,19 @@ class NeutronBackend final : public PyTorchBackendInterface {
 
     // Transpose outputs.
     for (int i = 0; i < cfg->numOutputs; i++) {
-      if (cfg->outputTranspositionFlags[i]) {
-        if (args[cfg->numInputs + i]->toTensor().sizes().size() < 3) {
+      auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
+      if (cfg->outputTranspositionFlags[i] &&
+          multipleChannelsPresent(arg.sizes())) {
+        if (arg.sizes().size() < 3) {
           ET_LOG(
               Error, "Unable to transpose 1D and 2D output to channel first");
           return Error::InvalidProgram;
         }
         transposeOutput(
             cfg->dcfg.outputs[i],
-            args[cfg->numInputs + i]->toTensor().mutable_data_ptr(),
-            args[cfg->numInputs + i]->toTensor().sizes(),
-            args[cfg->numInputs + i]->toTensor().element_size());
+            arg.mutable_data_ptr(),
+            arg.sizes(),
+            arg.element_size());
       }
     }
 
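
For reference, the channel-first restore performed after a Neutron run boils down to the index remapping below. This is only an illustrative NHWC-to-NCHW copy for 4-D float data, not the backend's transposeOutput, which works generically from the tensor sizes and element size it receives.

#include <cstddef>

// Copy a channel-last (NHWC) buffer into a channel-first (NCHW) buffer.
void nhwcToNchw(const float* src, float* dst,
                size_t N, size_t C, size_t H, size_t W) {
  for (size_t n = 0; n < N; ++n)
    for (size_t c = 0; c < C; ++c)
      for (size_t h = 0; h < H; ++h)
        for (size_t w = 0; w < W; ++w)
          dst[((n * C + c) * H + h) * W + w] =
              src[((n * H + h) * W + w) * C + c];
}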