@@ -25,37 +25,53 @@ namespace neutron {
25
25
#define ALIGN_SIZE (size ) \
26
26
((size + BUFFER_ALIGNMENT - 1 ) & (~(BUFFER_ALIGNMENT - 1 )))
27
27
28
+ // clang-format off
28
29
/* Header schema:
29
- +----------------------------------+-----------------------------------+
30
- | Input TensorFormats length (1B) | Output TensorFormats length (1B) |
31
- +----------------------------------+-----------------------------------+
32
- | 1st input tensor format (1B) | [nth* input tensor format (1B)] |
33
- +----------------------------------+-----------------------------------+
34
- | 1st output tensor format (1B) | [nth* output tensor format (1B)] |
35
- +----------------------------------+-----------------------------------+
30
+ +----------------------------+-----------------------------+------------------------+
31
+ | Neutron inputs length (1B) | Neutron outputs length (1B) | Input args length (1B) |
32
+ +----------------------------+-----------+-----------------+------------------------+
33
+ | 1st input tensor format (1B) | [nth* input tensor format (1B)] |
34
+ +----------------------------------------+------------------------------------------+
35
+ | 1st output tensor format (1B) | [nth* output tensor format (1B)] |
36
+ +----------------------------------------+------------------------------------------+
37
+ | 1st input map (1B) | [nth* input map (1B)] |
38
+ +----------------------------------------+------------------------------------------+
39
+ | 1st output map (1B) | [nth* output map (1B)] |
40
+ +----------------------------------------+------------------------------------------+
36
41
*/
42
+ // clang-format on
37
43
#define ITEM_SIZE 1 // 1 Byte
38
44
#define INPUT_TENSOR_FORMAT_LEN_POS 0
39
45
#define OUTPUT_TENSOR_FORMAT_LEN_POS 1
40
- #define INPUT_TENSOR_FORMAT_ARRAY_ADDR (base ) (base + 2 * ITEM_SIZE)
46
+ #define INPUT_ARGS_LEN_POS 2
47
+ #define INPUT_TENSOR_FORMAT_ARRAY_ADDR (base ) (base + 3 * ITEM_SIZE)
41
48
#define OUTPUT_TENSOR_FORMAT_ARRAY_ADDR (base ) \
42
- (base + 2 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS])
43
- #define PAYLOAD_ADDR (base ) \
44
- (base + \
45
- ALIGN_SIZE ( \
46
- 2 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS] + \
47
- base[OUTPUT_TENSOR_FORMAT_LEN_POS]))
49
+ (base + 3 * ITEM_SIZE + base[INPUT_TENSOR_FORMAT_LEN_POS])
50
+ #define INPUT_TENSOR_MAP_ARRAY_ADDR (base ) \
51
+ (base + 3 * ITEM_SIZE + 1 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
52
+ 1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])
53
+ #define OUTPUT_TENSOR_MAP_ARRAY_ADDR (base ) \
54
+ (base + 3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
55
+ 1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])
56
+ #define PAYLOAD_ADDR (base ) \
57
+ (base + \
58
+ ALIGN_SIZE ( \
59
+ 3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \
60
+ 2 * base[OUTPUT_TENSOR_FORMAT_LEN_POS]))
48
61
49
62
// Aggregate neutron model handle and data structures into one.
50
63
typedef struct {
51
64
// Number of Neutron inputs: loop bound when filling dcfg.inputs and when
// reading inputTranspositionFlags / inputMap.
int numInputs = 0 ;
52
65
// Number of Neutron outputs: loop bound for the output loops;
// dcfg.outputs[numOutputs] is the extra slot that receives the scratch buffer.
int numOutputs = 0 ;
66
// Taken from header byte INPUT_ARGS_LEN_POS. Output tensors are looked up at
// args[numInputArgs + outputMap[i]].
+ int numInputArgs = 0 ;
53
67
// Size passed to context.allocate() for the scratch buffer.
// NOTE(review): presumably bytes - confirm where it is assigned.
uint32_t scratchSize = 0 ;
54
68
// Model configuration (the visible code refers to its microcode, weights and
// kernels members).
NeutronModelConfig mcfg;
55
69
// Data configuration: holds the inputs[] / outputs[] buffer pointers that are
// filled from args or from temporary transposition buffers.
NeutronDataConfig dcfg;
56
70
// Model handle; NULL by default.
NeutronModelHandle nmh = NULL ;
57
71
// Points at the input tensor-format array inside the processed header. A
// non-zero entry i requests transposition of input i.
const uint8_t * inputTranspositionFlags;
58
72
// Points at the output tensor-format array inside the processed header. A
// non-zero entry i requests transposition of output i.
const uint8_t * outputTranspositionFlags;
73
// Points at the input map array inside the processed header; inputMap[i] is
// the index into args of the tensor feeding Neutron input i.
+ const uint8_t * inputMap;
74
// Points at the output map array inside the processed header; outputMap[i] is
// the offset (relative to numInputArgs) into args of Neutron output i.
+ const uint8_t * outputMap;
59
75
} NeutronConfig;
60
76
61
77
// Applied on outputs.
@@ -210,6 +226,15 @@ void transposeOutput(
210
226
}
211
227
}
212
228
229
// Reports whether a tensor shape has more than one channel, where the channel
// dimension is taken to be the third from the end (sizes[rank - 3]).
//
// Returns true when that dimension differs from 1. Shapes with fewer than
// three dimensions also return true, so callers that gate transposition on
// this result still reach their own "Unable to transpose 1D and 2D ..." check
// instead of silently skipping it.
+ bool multipleChannelsPresent (const ArrayRef<exec_aten::SizesType>& sizes) {
230
+ size_t length = sizes.size ();
231
// No channel dimension to inspect; defer to the caller's rank check.
+ if (length < 3 ) {
232
+ return true ;
233
+ }
234
// Third dimension from the end is treated as the channel count.
+ size_t C = sizes[length - 3 ];
235
+ return C != 1 ;
236
+ }
237
+
213
238
class NeutronBackend final : public PyTorchBackendInterface {
214
239
public:
215
240
NeutronBackend () {}
@@ -234,17 +259,19 @@ class NeutronBackend final : public PyTorchBackendInterface {
234
259
// cfg->mcfg.microcode
235
260
// cfg->mcfg.weights
236
261
// cfg->mcfg.kernels
237
- const uint8_t * transpositionFlags =
262
+ const uint8_t * payloadFlags =
238
263
static_cast <const uint8_t *>(processed->data ());
239
- int numInputs = transpositionFlags [INPUT_TENSOR_FORMAT_LEN_POS];
240
- int numOutputs = transpositionFlags [OUTPUT_TENSOR_FORMAT_LEN_POS];
241
- cfg->inputTranspositionFlags =
242
- INPUT_TENSOR_FORMAT_ARRAY_ADDR (transpositionFlags );
264
+ uint32_t numInputs = payloadFlags [INPUT_TENSOR_FORMAT_LEN_POS];
265
+ uint32_t numOutputs = payloadFlags [OUTPUT_TENSOR_FORMAT_LEN_POS];
266
+ cfg->numInputArgs = payloadFlags[INPUT_ARGS_LEN_POS];
267
+ cfg-> inputTranspositionFlags = INPUT_TENSOR_FORMAT_ARRAY_ADDR (payloadFlags );
243
268
cfg->outputTranspositionFlags =
244
- OUTPUT_TENSOR_FORMAT_ARRAY_ADDR (transpositionFlags);
269
+ OUTPUT_TENSOR_FORMAT_ARRAY_ADDR (payloadFlags);
270
+ cfg->inputMap = INPUT_TENSOR_MAP_ARRAY_ADDR (payloadFlags);
271
+ cfg->outputMap = OUTPUT_TENSOR_MAP_ARRAY_ADDR (payloadFlags);
245
272
246
273
const uint32_t * buffer = static_cast <const uint32_t *>(
247
- static_cast <const void *> PAYLOAD_ADDR (transpositionFlags ));
274
+ static_cast <const void *> PAYLOAD_ADDR (payloadFlags ));
248
275
uint32_t magicWord = buffer[0 ];
249
276
// Check valid microcode.
250
277
if (magicWord != 0x64434D6E ) {
@@ -314,39 +341,37 @@ class NeutronBackend final : public PyTorchBackendInterface {
314
341
cfg->dcfg .outputs [cfg->numOutputs ] =
315
342
static_cast <void *>(context.allocate (cfg->scratchSize , 16 ));
316
343
317
- // Set inputs and outputs from args.
344
+ // Set inputs from args.
345
+ // Transpose inputs if needed.
318
346
for (int i = 0 ; i < cfg->numInputs ; i++) {
319
- cfg->dcfg .inputs [i] = args[i]->toTensor ().const_data_ptr ();
320
- }
321
- for (int i = 0 ; i < cfg->numOutputs ; i++) {
322
- cfg->dcfg .outputs [i] =
323
- args[cfg->numInputs + i]->toTensor ().mutable_data_ptr ();
324
- }
325
-
326
- // Transpose inputs.
327
- for (int i = 0 ; i < cfg->numInputs ; i++) {
328
- if (cfg->inputTranspositionFlags [i]) {
329
- if (args[i]->toTensor ().sizes ().size () < 3 ) {
347
+ auto arg = args[cfg->inputMap [i]]->toTensor ();
348
+ if (cfg->inputTranspositionFlags [i] &&
349
+ multipleChannelsPresent (arg.sizes ())) {
350
+ if (arg.sizes ().size () < 3 ) {
330
351
ET_LOG (Error, " Unable to transpose 1D and 2D input to channel last" );
331
352
return Error::InvalidProgram;
332
353
}
333
354
// Allocate buffer, the allocator is reset after each PTE instruction.
334
- void * buffer = context.allocate (args[i]-> toTensor () .nbytes (), 16 );
355
+ void * buffer = context.allocate (arg .nbytes ());
335
356
transposeInput (
336
- args[i]->toTensor ().const_data_ptr (),
337
- buffer,
338
- args[i]->toTensor ().sizes (),
339
- args[i]->toTensor ().element_size ());
357
+ arg.const_data_ptr (), buffer, arg.sizes (), arg.element_size ());
340
358
cfg->dcfg .inputs [i] = buffer;
359
+ } else {
360
+ cfg->dcfg .inputs [i] = arg.const_data_ptr ();
341
361
}
342
362
}
343
- // Redirect outputs.
363
+
364
+ // Set outputs from args.
365
+ // Redirect outputs if needed before transposition.
344
366
for (int i = 0 ; i < cfg->numOutputs ; i++) {
345
- if (cfg->outputTranspositionFlags [i]) {
367
+ auto arg = args[cfg->numInputArgs + cfg->outputMap [i]]->toTensor ();
368
+ if (cfg->outputTranspositionFlags [i] &&
369
+ multipleChannelsPresent (arg.sizes ())) {
346
370
// Allocate buffer, the allocator is reset after each PTE instruction.
347
- void * buffer =
348
- context.allocate (args[cfg->numInputs + i]->toTensor ().nbytes (), 16 );
371
+ void * buffer = context.allocate (arg.nbytes ());
349
372
cfg->dcfg .outputs [i] = buffer;
373
+ } else {
374
+ cfg->dcfg .outputs [i] = arg.mutable_data_ptr ();
350
375
}
351
376
}
352
377
@@ -368,17 +393,19 @@ class NeutronBackend final : public PyTorchBackendInterface {
368
393
369
394
// Transpose outputs.
370
395
for (int i = 0 ; i < cfg->numOutputs ; i++) {
371
- if (cfg->outputTranspositionFlags [i]) {
372
- if (args[cfg->numInputs + i]->toTensor ().sizes ().size () < 3 ) {
396
+ auto arg = args[cfg->numInputArgs + cfg->outputMap [i]]->toTensor ();
397
+ if (cfg->outputTranspositionFlags [i] &&
398
+ multipleChannelsPresent (arg.sizes ())) {
399
+ if (arg.sizes ().size () < 3 ) {
373
400
ET_LOG (
374
401
Error, " Unable to transpose 1D and 2D output to channel first" );
375
402
return Error::InvalidProgram;
376
403
}
377
404
transposeOutput (
378
405
cfg->dcfg .outputs [i],
379
- args[cfg-> numInputs + i]-> toTensor () .mutable_data_ptr (),
380
- args[cfg-> numInputs + i]-> toTensor () .sizes (),
381
- args[cfg-> numInputs + i]-> toTensor () .element_size ());
406
+ arg .mutable_data_ptr (),
407
+ arg .sizes (),
408
+ arg .element_size ());
382
409
}
383
410
}
384
411
0 commit comments