Skip to content

Commit f481b12

Browse files
committed
Support multiple batches (N > 1) in Arm runtime
Extends testing of multiple batches on FVP for multiple operators. Note that not all operators are currently supported at the compiler level. Change-Id: I59aee4b80fd058931e02806a83ca639533e7c76b
1 parent eca5d9f commit f481b12

File tree

13 files changed

+120
-63
lines changed

13 files changed

+120
-63
lines changed

backends/arm/runtime/ArmBackendEthosU.cpp

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,10 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
241241
event_tracer,
242242
"+ArmBackend::execute()handles.input.permute_CHW_to_HWC()");
243243
// permuted byte copy CHW to HWC
244-
permute_CHW_to_HWC(
244+
permute_NCHW_to_NHWC(
245245
tensor_in.mutable_data_ptr<char>(),
246246
scratch_addr,
247+
tensor_in.size(0),
247248
tensor_in.size(1),
248249
tensor_in.size(2),
249250
tensor_in.size(3));
@@ -342,9 +343,10 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
342343
"+ArmBackend::execute()handles.output.permute_HWC_to_CHW()");
343344

344345
char* output_address = (char*)output_addr;
345-
permute_HWC_to_CHW(
346+
permute_NHWC_to_NCHW(
346347
output_address,
347348
tensor_out.mutable_data_ptr<char>(),
349+
tensor_out.size(0),
348350
tensor_out.size(1),
349351
tensor_out.size(2),
350352
tensor_out.size(3));
@@ -420,21 +422,53 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
420422
return Error::Ok;
421423
}
422424

423-
void permute_CHW_to_HWC(char* input, char* output, int C, int H, int W)
424-
const {
425-
for (int i = 0; i != H * W; ++i) {
426-
for (int j = 0; j < C; ++j) {
427-
output[i * C + j] = input[i + j * W * H];
425+
void permute_NCHW_to_NHWC(
426+
const char* input,
427+
char* output,
428+
const int N,
429+
const int C,
430+
const int H,
431+
const int W) const {
432+
for (int n = 0; n < N; n++) {
433+
for (int c = 0; c < C; c++) {
434+
for (int i = 0; i < H * W; i++) {
435+
*output = *input;
436+
// Next element
437+
input++;
438+
output += C;
439+
}
440+
// Rewind output and increment to next channel
441+
output -= (H * W * C);
442+
output++;
428443
}
444+
// Rewind output and increment to next batch
445+
output -= C;
446+
output += (H * W * C);
429447
}
430448
}
431449

432-
void permute_HWC_to_CHW(char* input, char* output, int C, int H, int W)
433-
const {
434-
for (int i = 0; i != H * W; ++i) {
435-
for (int j = 0; j < C; ++j) {
436-
output[i + j * W * H] = input[i * C + j];
450+
void permute_NHWC_to_NCHW(
451+
const char* input,
452+
char* output,
453+
const int N,
454+
const int C,
455+
const int H,
456+
const int W) const {
457+
for (int n = 0; n < N; n++) {
458+
for (int i = 0; i < H * W; i++) {
459+
for (int c = 0; c < C; c++) {
460+
*output = *input;
461+
// Next channel
462+
input++;
463+
output += H * W;
464+
}
465+
// Rewind output and increment to next element
466+
output -= (H * W * C);
467+
output++;
437468
}
469+
// Rewind output and increment to next batch
470+
output -= H * W;
471+
output += (H * W * C);
438472
}
439473
}
440474
};

backends/arm/test/ops/test_add.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class Add(torch.nn.Module):
2525
(torch.FloatTensor([1, 2, 3, 5, 7]),),
2626
(3 * torch.ones(8),),
2727
(10 * torch.randn(8),),
28-
(torch.ones(1, 1, 4, 4),),
29-
(torch.ones(1, 3, 4, 2),),
28+
(torch.ones(2, 1, 4, 4),),
29+
(torch.ones(2, 3, 4, 2),),
3030
]
3131

3232
def forward(self, x):
@@ -38,10 +38,10 @@ class Add2(torch.nn.Module):
3838
torch.FloatTensor([1, 2, 3, 5, 7]),
3939
(torch.FloatTensor([2, 1, 2, 1, 10])),
4040
),
41-
(torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
42-
(torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
43-
(torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
44-
(10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
41+
(torch.ones(2, 10, 4, 6), torch.ones(2, 10, 4, 6)),
42+
(torch.randn(2, 3, 4, 4), torch.randn(2, 3, 4, 4)),
43+
(torch.randn(2, 1, 4, 4), torch.ones(2, 1, 4, 1)),
44+
(10000 * torch.randn(2, 1, 4, 4), torch.randn(2, 1, 4, 1)),
4545
]
4646

4747
def __init__(self):

backends/arm/test/ops/test_avg_pool.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
("randn", torch.randn(1, 16, 50, 32), [4, 2, 0]),
3030
]
3131

32+
test_data_suite_mult_batches = [
33+
# (test_name, test_data, [kernel_size, stride, padding])
34+
("rand", torch.rand(2, 16, 50, 32), [4, 2, 0]),
35+
]
36+
3237

3338
class TestAvgPool2d(unittest.TestCase):
3439
"""Tests AvgPool2d."""
@@ -168,3 +173,31 @@ def test_avgpool2d_tosa_u85_BI(
168173
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
169174
(test_data,),
170175
)
176+
177+
@parameterized.expand(test_data_suite_mult_batches)
178+
@conftest.expectedFailureOnFVP # See MLTORCH-517
179+
def test_avgpool2d_tosa_u55_BI_mult_batches(
180+
self,
181+
test_name: str,
182+
test_data: torch.Tensor,
183+
model_params: int | Tuple[int, int],
184+
):
185+
self._test_avgpool2d_tosa_ethos_BI_pipeline(
186+
self.AvgPool2d(*model_params),
187+
common.get_u55_compile_spec(permute_memory_to_nhwc=True),
188+
(test_data,),
189+
)
190+
191+
@parameterized.expand(test_data_suite_mult_batches)
192+
@conftest.expectedFailureOnFVP # See MLTORCH-517
193+
def test_avgpool2d_tosa_u85_BI_mult_batches(
194+
self,
195+
test_name: str,
196+
test_data: torch.Tensor,
197+
model_params: int | Tuple[int, int],
198+
):
199+
self._test_avgpool2d_tosa_ethos_BI_pipeline(
200+
self.AvgPool2d(*model_params),
201+
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
202+
(test_data,),
203+
)

backends/arm/test/ops/test_batch_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# (test_name, test_data, [num_features, affine, track_running_stats, weight, bias, running_mean, running_var,] )
2323
(
2424
"zeros_affineT_runStatsT_default_weight_bias_mean_var",
25-
torch.zeros(1, 32, 112, 112),
25+
torch.zeros(2, 32, 112, 112),
2626
[
2727
32,
2828
True,

backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,10 @@
161161
("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1),
162162
("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1),
163163
("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias),
164+
("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3),
164165
]
165166

166167
testsuite_conv2d_u85_xfails = [
167-
("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3),
168168
("two_dw_conv2d", two_dw_conv2d),
169169
]
170170

backends/arm/test/ops/test_div.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@
3232
torch.ones(5) * (-1),
3333
None,
3434
),
35-
(
36-
"op_div_rank1_rand",
37-
torch.rand(5) * 5,
38-
torch.rand(5) * 5,
39-
None,
40-
),
4135
(
4236
"op_div_rank4_ones",
4337
torch.ones(5, 10, 25, 20),
4438
torch.ones(5, 10, 25, 20),
4539
None,
4640
),
41+
(
42+
"op_div_rank1_rand",
43+
torch.rand(5) * 5,
44+
torch.rand(5) * 5,
45+
None,
46+
),
4747
(
4848
"op_div_rank4_negative_ones",
4949
(-1) * torch.ones(5, 10, 25, 20),
@@ -183,7 +183,7 @@ def test_div_tosa_BI(
183183
test_data = (input_, other_)
184184
self._test_div_tosa_BI_pipeline(self.Div(), test_data)
185185

186-
@parameterized.expand(test_data_suite[:2])
186+
@parameterized.expand(test_data_suite[:3])
187187
def test_div_u55_BI(
188188
self,
189189
test_name: str,
@@ -197,7 +197,7 @@ def test_div_u55_BI(
197197
)
198198

199199
# Numerical issues on FVP likely due to mul op, MLETORCH-521
200-
@parameterized.expand(test_data_suite[2:])
200+
@parameterized.expand(test_data_suite[3:])
201201
@conftest.expectedFailureOnFVP
202202
def test_div_u55_BI_xfails(
203203
self,
@@ -211,7 +211,7 @@ def test_div_u55_BI_xfails(
211211
self.Div(), common.get_u55_compile_spec(), test_data
212212
)
213213

214-
@parameterized.expand(test_data_suite[:2])
214+
@parameterized.expand(test_data_suite[:3])
215215
def test_div_u85_BI(
216216
self,
217217
test_name: str,
@@ -225,7 +225,7 @@ def test_div_u85_BI(
225225
)
226226

227227
# Numerical issues on FVP likely due to mul op, MLETORCH-521
228-
@parameterized.expand(test_data_suite[2:])
228+
@parameterized.expand(test_data_suite[3:])
229229
@conftest.expectedFailureOnFVP
230230
def test_div_u85_BI_xfails(
231231
self,

backends/arm/test/ops/test_exp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
test_data_suite = [
1919
# (test_name, test_data)
20-
("zeros", torch.zeros(1, 10, 10, 10)),
20+
("zeros", torch.zeros(2, 10, 10, 10)),
2121
("ones", torch.ones(10, 10, 10)),
2222
("rand", torch.rand(10, 10) - 0.5),
23-
("randn_pos", torch.randn(1, 4, 4, 4) + 10),
23+
("randn_pos", torch.randn(2, 4, 4, 4) + 10),
2424
("randn_neg", torch.randn(10) - 10),
2525
("ramp", torch.arange(-16, 16, 0.2)),
2626
]

backends/arm/test/ops/test_hardtanh.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323

2424
test_data_suite = [
2525
# (test_name, test_data)
26-
("zeros", torch.zeros(1, 10, 10, 10)),
26+
("zeros", torch.zeros(2, 10, 10, 10)),
2727
("ones", torch.ones(10, 10, 10)),
2828
("rand", torch.rand(10, 10) - 0.5),
2929
("randn_pos", torch.randn(10) + 10),
30-
("randn_neg", torch.randn(10) - 10),
30+
("randn_neg", torch.randn(2, 10, 10, 10) - 10),
3131
("ramp", torch.arange(-16, 16, 0.2)),
3232
]
3333

backends/arm/test/ops/test_layer_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
test_data_suite = [
1818
# (test_name, test_data, [normalized_shape, eps, elementwise_affine, has_bias] )
19-
("randn_last_dim", torch.randn(1, 5, 5, 5), [[5]]),
20-
("rand_last_two_dims", torch.rand(1, 5, 5, 5), [[5, 5]]),
19+
("randn_last_dim", torch.randn(2, 5, 5, 5), [[5]]),
20+
("rand_last_two_dims", torch.rand(2, 5, 5, 5), [[5, 5]]),
2121
(
2222
"rand_last_two_dims_not_elementwise_affine",
2323
torch.rand(1, 5, 5, 5),

backends/arm/test/ops/test_log.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717

1818
test_data_suite = [
1919
# (test_name, test_data)
20-
("ones_rank4", torch.ones(1, 10, 10, 10)),
20+
("ones_rank4", torch.ones(2, 10, 10, 10)),
2121
("ones_rank3", torch.ones(10, 10, 10)),
2222
("rand", torch.rand(10, 10) + 0.001),
2323
("randn_pos", torch.randn(10) + 10),
24-
("randn_spread", torch.max(torch.Tensor([0.0]), torch.randn(10) * 100)),
24+
("randn_spread", torch.max(torch.Tensor([0.0]), torch.randn(2, 10, 10, 10) * 100)),
2525
("ramp", torch.arange(0.01, 20, 0.2)),
2626
]
2727

0 commit comments

Comments
 (0)