Commit 9e0ac71

feat: Enable non-transposed F32 reorders
Resolves [MLINFSW-1095]

- Enable 4 interleaved non-transposed F32 reorders
- Enable 8 interleaved non-transposed F32 reorders
- Minor fixes to enable non-transposed reorders
- Refactor tests to be able to test reorders
- Add non-transposed reorder tests

Change-Id: I674d592669d5217570c111e486236e8537832c18
Signed-off-by: Ryo Suzuki <[email protected]>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/14222
Comments-Addressed: Arm Jenkins <[email protected]>
Benchmark: Arm Jenkins <[email protected]>
Reviewed-by: Gunes Bayir <[email protected]>
Tested-by: Arm Jenkins <[email protected]>
1 parent fcd1b0b commit 9e0ac71
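For orientation, here is a rough, self-contained sketch of what an interleave-by-4, non-transposed reorder produces. The function name and exact layout details are illustrative assumptions of mine, not the library's; the authoritative semantics live in arm_gemm::Transform and the OHWIo4/OHWIo8 weight formats. The idea: rows are grouped into panels of four and written element-interleaved along x, with zero padding for a ragged final panel.

// reorder_interleave4 is a hypothetical toy, not the library kernel.
#include <cstdio>
#include <vector>

// src is kmax x xmax, row-major; dst holds ceil(kmax/4)*4 * xmax floats.
void reorder_interleave4(const float *src, float *dst, int kmax, int xmax)
{
    const int interleave_by = 4;
    float *out = dst;
    for (int k0 = 0; k0 < kmax; k0 += interleave_by) // one panel of 4 rows
    {
        for (int x = 0; x < xmax; ++x) // walk along x
        {
            for (int r = 0; r < interleave_by; ++r) // interleave the panel's rows
            {
                const int k = k0 + r;
                *out++ = (k < kmax) ? src[k * xmax + x] : 0.0f; // zero-pad ragged tail
            }
        }
    }
}

int main()
{
    const int kmax = 3, xmax = 2; // ragged: the panel pads one missing row
    std::vector<float> src = {1, 2, 3, 4, 5, 6}; // 3x2, row-major
    std::vector<float> dst(4 * xmax, 0.0f);
    reorder_interleave4(src.data(), dst.data(), kmax, xmax);
    for (float v : dst)
        std::printf("%g ", v); // prints: 1 3 5 0 2 4 6 0
    std::printf("\n");
    return 0;
}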

7 files changed: +120 −100 lines

src/core/NEON/kernels/NEReorderKernel.cpp

Lines changed: 72 additions & 48 deletions
@@ -60,7 +60,10 @@ struct TransformParams
 
 std::map<TransformParams, void (*)(float *, const float *, int, int, int, int, int)> supported_float_transforms = {
     {{4, 1, true, arm_gemm::VLType::None}, &arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None, float, float>},
+    {{4, 1, false, arm_gemm::VLType::None}, &arm_gemm::Transform<4, 1, false, arm_gemm::VLType::None, float, float>},
+    {{8, 1, false, arm_gemm::VLType::None}, &arm_gemm::Transform<8, 1, false, arm_gemm::VLType::None, float, float>},
 #ifdef ARM_COMPUTE_ENABLE_SVE
+    // When there is an asm kernel, use formula in transform.cpp to get the interleave_by_ number
     {{1, 1, true, arm_gemm::VLType::SVE}, &arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE, float, float>},
 #endif // ARM_COMPUTE_ENABLE_SVE
 };

@@ -72,6 +75,17 @@ std::map<TransformParams, void (*)(bfloat16 *, const float *, int, int, int, int
 #endif // ARM_COMPUTE_ENABLE_SVE
 };
 
+#ifdef ARM_COMPUTE_ENABLE_SVE
+
+// Calculate the interleave_by parameter needed for SVE kernels
+// using the formula listed in transform.cpp
+template <typename TOut>
+inline int get_sve_interleave_by(int interleave_by, int block_by)
+{
+    return interleave_by / (get_vector_length<TOut>() / block_by);
+}
+#endif // ARM_COMPUTE_ENABLE_SVE
+
 } // namespace
 
 void NEReorderKernel::run(const Window &window, const ThreadInfo &info)

@@ -84,7 +98,7 @@ void NEReorderKernel::run(const Window &window, const ThreadInfo &info)
     const int jump_rows = ksize_rows_elements * window.x().start();
     const int k_start = window.x().start() * _ksize;
     const int k_end = std::min(window.x().end() * _ksize, _kmax);
-    const int stride = _kmax;
+    const int stride = _transpose ? _kmax : _xmax;
     const int block_by = arm_compute::block_by(_output_wf);
     const int interleave_by = arm_compute::interleave_by(_output_wf);
     ARM_COMPUTE_ERROR_ON(interleave_by != 4 && interleave_by != 8);

@@ -96,22 +110,46 @@ void NEReorderKernel::run(const Window &window, const ThreadInfo &info)
     {
         case DataType::F32:
         {
-            // Interleave_by is different for SVE cases. Refer to src/core/NEON/kernels/arm_gemm/transform.cpp
-            const int interleave_by_ = interleave_by == 8 ? interleave_by / (8 / block_by) : 4;
-            supported_float_transforms[{interleave_by_, block_by, _transpose,
-                                        interleave_by == 8 ? arm_gemm::VLType::SVE : arm_gemm::VLType::None}](
-                reinterpret_cast<float *>(_output->buffer()) + jump_rows, reinterpret_cast<float *>(_input->buffer()),
-                stride, k_start, k_end, 0, _xmax);
+            void (*transform_func)(float *, const float *, int, int, int, int, int) = nullptr;
+#ifdef ARM_COMPUTE_ENABLE_SVE
+            if (CPUInfo::get().has_sve())
+            {
+                TransformParams tparams = {get_sve_interleave_by<float>(interleave_by, block_by), block_by, _transpose,
+                                           arm_gemm::VLType::SVE};
+                if (supported_float_transforms.count(tparams))
+                {
+                    transform_func = supported_float_transforms[tparams];
+                }
+            }
+#endif // ARM_COMPUTE_ENABLE_SVE
+            if (transform_func == nullptr)
+            {
+                transform_func =
+                    supported_float_transforms[{interleave_by, block_by, _transpose, arm_gemm::VLType::None}];
+            }
+            transform_func(reinterpret_cast<float *>(_output->buffer()) + jump_rows,
+                           reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
             break;
         }
         case DataType::BFLOAT16:
         {
-            // Interleave_by is different for SVE cases. Refer to transform.cpp
-            const int interleave_by_ = interleave_by == 8 ? interleave_by / (16 / block_by) : 4;
-            supported_bf16_transforms[{interleave_by_, block_by, _transpose,
-                                       interleave_by == 8 ? arm_gemm::VLType::SVE : arm_gemm::VLType::None}](
-                reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
-                reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+            void (*transform_func)(bfloat16 *, const float *, int, int, int, int, int) = nullptr;
+#ifdef ARM_COMPUTE_ENABLE_SVE
+            if (CPUInfo::get().has_sve())
+            {
+                TransformParams tparams = {get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
+                                           _transpose, arm_gemm::VLType::SVE};
+                if (supported_bf16_transforms.count(tparams))
+                    transform_func = supported_bf16_transforms[tparams];
+            }
+#endif // ARM_COMPUTE_ENABLE_SVE
+            if (transform_func == nullptr)
+            {
+                transform_func =
+                    supported_bf16_transforms[{interleave_by, block_by, _transpose, arm_gemm::VLType::None}];
+            }
+            transform_func(reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
+                           reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
             break;
         }
         default:

@@ -237,52 +275,38 @@ Status NEReorderKernel::validate(const ITensorInfo *input,
     int interleave_by = arm_compute::interleave_by(output_wf);
     int block_by = arm_compute::block_by(output_wf);
     ARM_COMPUTE_RETURN_ERROR_ON(interleave_by != 4 && interleave_by != 8);
-    if (interleave_by == 8)
-    {
-#ifdef ARM_COMPUTE_ENABLE_SVE
-        ARM_COMPUTE_RETURN_ERROR_ON(!Scheduler::get().cpu_info().has_sve() ||
-                                    arm_gemm::utils::get_vector_length<float>() != 8);
-#else  // ARM_COMPUTE_ENABLE_SVE
-        ARM_COMPUTE_RETURN_ERROR_MSG("SVE format requested on non-SVE machine");
-#endif // ARM_COMPUTE_ENABLE_SVE
-    }
     ksize = interleave_by;
 
-    if (transpose)
-    {
-        // output k_dim needs to be same as input but multiple of ksize
-        int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_k_dim, ksize);
-        ARM_COMPUTE_RETURN_ERROR_ON(rnd_up_input_kdim != output_k_dim);
-        // output x_dim needs to be same as input
-        ARM_COMPUTE_RETURN_ERROR_ON(input_x_dim != output_x_dim);
-    }
-    else
-    {
-        // output x_dim needs to be same as input but multiple of ksize
-        int32_t rnd_up_input_xdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_x_dim, ksize);
-        ARM_COMPUTE_RETURN_ERROR_ON(rnd_up_input_xdim != output_x_dim);
-        // output k_dim needs to be same as input
-        ARM_COMPUTE_RETURN_ERROR_ON(input_k_dim != output_k_dim);
-    }
+    // output k_dim needs to be same as input but multiple of ksize
+    int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_k_dim, ksize);
+    ARM_COMPUTE_RETURN_ERROR_ON(rnd_up_input_kdim != output_k_dim);
+    // output x_dim needs to be same as input
+    ARM_COMPUTE_RETURN_ERROR_ON(input_x_dim != output_x_dim);
 
     switch (output->data_type())
     {
         case DataType::F32:
         {
-            // Interleave_by is different for SVE cases. Refer to transform.cpp
-            const int interleave_by_ = interleave_by == 8 ? interleave_by / (8 / block_by) : 4;
-            ARM_COMPUTE_RETURN_ERROR_ON(!supported_float_transforms.count(
-                {interleave_by_, block_by, transpose,
-                 interleave_by == 8 ? arm_gemm::VLType::SVE : arm_gemm::VLType::None}));
+#ifdef ARM_COMPUTE_ENABLE_SVE
+            if (CPUInfo::get().has_sve() &&
+                supported_float_transforms.count({get_sve_interleave_by<float>(interleave_by, block_by), block_by,
                                                  transpose, arm_gemm::VLType::SVE}))
+                break;
+#endif // ARM_COMPUTE_ENABLE_SVE
+            ARM_COMPUTE_RETURN_ERROR_ON(
+                !supported_float_transforms.count({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
             break;
         }
         case DataType::BFLOAT16:
        {
-            // Interleave_by is different for SVE cases. Refer to transform.cpp
-            const int interleave_by_ = interleave_by == 8 ? interleave_by / (16 / block_by) : 4;
-            ARM_COMPUTE_RETURN_ERROR_ON(!supported_bf16_transforms.count(
-                {interleave_by_, block_by, transpose,
-                 interleave_by == 8 ? arm_gemm::VLType::SVE : arm_gemm::VLType::None}));
+#ifdef ARM_COMPUTE_ENABLE_SVE
+            if (CPUInfo::get().has_sve() &&
+                supported_bf16_transforms.count({get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
                                                 transpose, arm_gemm::VLType::SVE}))
+                break;
+#endif // ARM_COMPUTE_ENABLE_SVE
+            ARM_COMPUTE_RETURN_ERROR_ON(
+                !supported_bf16_transforms.count({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
            break;
        }
        default:
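Two details above benefit from a worked example. First, the stride now depends on orientation: a transposed reorder walks the source with stride _kmax, while a non-transposed one uses _xmax. Second, get_sve_interleave_by decides whether an SVE table entry exists for the requested format. The snippet below is my standalone restatement of the formula with assumed values; the real vector length comes from arm_gemm's get_vector_length<T>().

// Standalone restatement of the SVE interleave mapping, for illustration only.
constexpr int sve_interleave_by(int interleave_by, int vector_length, int block_by)
{
    return interleave_by / (vector_length / block_by);
}

// On a 256-bit SVE machine a vector holds 8 floats, so a requested
// interleave-by-8 F32 reorder maps to the {1, 1, transpose, VLType::SVE}
// entry of supported_float_transforms:
static_assert(sve_interleave_by(8, 8, 1) == 1, "OHWIo8 F32 on 256-bit SVE");
// If no SVE entry matches, run() and validate() fall back to the generic
// {interleave_by, block_by, transpose, VLType::None} table entry.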

src/core/NEON/kernels/arm_gemm/transform.cpp

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2024 Arm Limited.
+ * Copyright (c) 2021-2025 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *

@@ -126,6 +126,9 @@ void Transform(
 
 #include "transforms/list.hpp"
 
+template void Transform<4, 1, false, VLType::None>(float *, const float *, int, int, int, int, int);
+template void Transform<8, 1, false, VLType::None>(float *, const float *, int, int, int, int, int);
+
 // We don't have assembler transforms for AArch32, generate templated ones here.
 #ifdef __arm__
 template void Transform<8, 1, true, VLType::None>(float *, const float *, int, int, int, int, int);
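Context for the two added lines: transform.cpp compiles the generic Transform template body in a single translation unit, so every <IntBy, BlockBy, Transposed, VLType> combination referenced elsewhere (here, by NEReorderKernel's new non-transposed map entries) must be explicitly instantiated there, or the build fails to link with undefined references. A generic sketch of the pattern, with hypothetical names and a simplified signature:

// transform_like.h -- declaration only; the template body is not visible here.
template <int IntBy, bool Transposed, typename T>
void TransformLike(T *out, const T *in, int stride, int k0, int kmax, int x0, int xmax);

// transform_like.cpp -- definition plus explicit instantiations for the
// combinations other translation units are allowed to call.
template <int IntBy, bool Transposed, typename T>
void TransformLike(T *out, const T *in, int stride, int k0, int kmax, int x0, int xmax)
{
    // ... generic implementation ...
}

template void TransformLike<4, false, float>(float *, const float *, int, int, int, int, int);
template void TransformLike<8, false, float>(float *, const float *, int, int, int, int, int);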

tests/datasets/ReorderLayerDataset.h

Lines changed: 27 additions & 34 deletions
@@ -38,20 +38,18 @@ namespace datasets
 class ReorderLayerDataset
 {
 public:
-    using type = std::tuple<TensorShape, TensorShape, WeightFormat, WeightFormat, bool>;
+    using type = std::tuple<TensorShape, TensorShape, WeightFormat, WeightFormat>;
 
     struct iterator
     {
         iterator(std::vector<TensorShape>::const_iterator in_it,
                  std::vector<TensorShape>::const_iterator out_it,
                  std::vector<WeightFormat>::const_iterator _wf_in_it,
-                 std::vector<WeightFormat>::const_iterator _wf_out_it,
-                 std::vector<bool>::const_iterator _transposes_it)
+                 std::vector<WeightFormat>::const_iterator _wf_out_it)
             : _in_it{ std::move(in_it) },
               _out_it{ std::move(out_it) },
              _wf_in_it{ std::move(_wf_in_it) },
-              _wf_out_it{ std::move(_wf_out_it) },
-              _transposes_it{ std::move(_transposes_it) }
+              _wf_out_it{ std::move(_wf_out_it) }
        {
        }
 
@@ -62,13 +60,12 @@ class ReorderLayerDataset
            description << "Out=" << *_out_it << ":";
            description << "Wf_In=" << *_wf_in_it << ":";
            description << "Wf_Out=" << *_wf_out_it;
-            description << "Transpose=" << *_transposes_it;
            return description.str();
        }
 
        ReorderLayerDataset::type operator*() const
        {
-            return std::make_tuple(*_in_it, *_out_it, *_wf_in_it, *_wf_out_it, *_transposes_it);
+            return std::make_tuple(*_in_it, *_out_it, *_wf_in_it, *_wf_out_it);
        }
 
        iterator &operator++()
@@ -77,7 +74,6 @@ class ReorderLayerDataset
            ++_out_it;
            ++_wf_in_it;
            ++_wf_out_it;
-            ++_transposes_it;
 
            return *this;
        }
@@ -87,26 +83,24 @@ class ReorderLayerDataset
        std::vector<TensorShape>::const_iterator _out_it;
        std::vector<WeightFormat>::const_iterator _wf_in_it;
        std::vector<WeightFormat>::const_iterator _wf_out_it;
-        std::vector<bool>::const_iterator _transposes_it;
    };
 
    iterator begin() const
    {
-        return iterator(_in_shapes.begin(), _out_shapes.begin(), _in_wfs.begin(), _out_wfs.begin(), _transposes.begin());
+        return iterator(_in_shapes.begin(), _out_shapes.begin(), _in_wfs.begin(), _out_wfs.begin());
    }
 
    int size() const
    {
-        return std::min(_in_shapes.size(), std::min(_out_shapes.size(), std::min(_in_wfs.size(), std::min(_out_wfs.size(), _transposes.size()))));
+        return std::min(_in_shapes.size(), std::min(_out_shapes.size(), std::min(_in_wfs.size(), _out_wfs.size())));
    }
 
-    void add_config(TensorShape in, TensorShape out, WeightFormat in_wf, WeightFormat out_wf, bool transpose)
+    void add_config(TensorShape in, TensorShape out, WeightFormat in_wf, WeightFormat out_wf)
    {
        _in_shapes.emplace_back(std::move(in));
        _out_shapes.emplace_back(std::move(out));
        _in_wfs.emplace_back(std::move(in_wf));
        _out_wfs.emplace_back(std::move(out_wf));
-        _transposes.emplace_back(transpose);
    }
 
    // protected:
@@ -118,7 +112,6 @@ class ReorderLayerDataset
    std::vector<TensorShape> _out_shapes{};
    std::vector<WeightFormat> _in_wfs{};
    std::vector<WeightFormat> _out_wfs{};
-    std::vector<bool> _transposes{};
 };
 
 /** [ReorderLayer datasets] **/
@@ -128,16 +121,16 @@ class ReorderLayerDatasetBlock4 final : public ReorderLayerDataset
 public:
    ReorderLayerDatasetBlock4()
    {
-        add_config(TensorShape(10U, 9U), TensorShape(10U, 12U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(16U, 16U), TensorShape(16U, 16U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(10U, 511U), TensorShape(10U, 512U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(234U, 301U), TensorShape(234U, 304U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(1024U, 1024U), TensorShape(1024U, 1024U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(10U, 9U, 1U, 1U), TensorShape(10U, 12U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(16U, 16U, 1U, 1U), TensorShape(16U, 16U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(10U, 511U, 1U, 1U), TensorShape(10U, 512U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(234U, 301U, 1U, 1U), TensorShape(234U, 304U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
-        add_config(TensorShape(1024U, 1024U, 1U, 1U), TensorShape(1024U, 1024U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4, true);
+        add_config(TensorShape(10U, 9U), TensorShape(10U, 12U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(16U, 16U), TensorShape(16U, 16U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(10U, 511U), TensorShape(10U, 512U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(234U, 301U), TensorShape(234U, 304U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(1024U, 1024U), TensorShape(1024U, 1024U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(10U, 9U, 1U, 1U), TensorShape(10U, 12U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(16U, 16U, 1U, 1U), TensorShape(16U, 16U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(10U, 511U, 1U, 1U), TensorShape(10U, 512U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(234U, 301U, 1U, 1U), TensorShape(234U, 304U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4);
+        add_config(TensorShape(1024U, 1024U, 1U, 1U), TensorShape(1024U, 1024U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo4);
    }
 };
 
@@ -146,16 +139,16 @@ class ReorderLayerDatasetBlock8 final : public ReorderLayerDataset
 public:
    ReorderLayerDatasetBlock8()
    {
-        add_config(TensorShape(10U, 9U), TensorShape(10U, 16U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(16U, 16U), TensorShape(16U, 16U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(10U, 511U), TensorShape(10U, 512U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(234U, 301U), TensorShape(234U, 304U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(1024U, 1024U), TensorShape(1024U, 1024U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(10U, 9U, 1U, 1U), TensorShape(10U, 16U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(16U, 16U, 1U, 1U), TensorShape(16U, 16U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(10U, 511U, 1U, 1U), TensorShape(10U, 512U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(234U, 301U, 1U, 1U), TensorShape(234U, 304U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
-        add_config(TensorShape(1024U, 1024U, 1U, 1U), TensorShape(1024U, 1024U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8, true);
+        add_config(TensorShape(10U, 9U), TensorShape(10U, 16U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(16U, 16U), TensorShape(16U, 16U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(10U, 511U), TensorShape(10U, 512U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(234U, 301U), TensorShape(234U, 304U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(1024U, 1024U), TensorShape(1024U, 1024U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(10U, 9U, 1U, 1U), TensorShape(10U, 16U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(16U, 16U, 1U, 1U), TensorShape(16U, 16U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(10U, 511U, 1U, 1U), TensorShape(10U, 512U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(234U, 301U, 1U, 1U), TensorShape(234U, 304U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8);
+        add_config(TensorShape(1024U, 1024U, 1U, 1U), TensorShape(1024U, 1024U, 1U, 1U), WeightFormat::OHWI, WeightFormat::OHWIo8);
    }
 };
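The transpose flag leaves the dataset and is instead crossed in at the test level (see the test changes below), so each shape configuration is written once. Note how every config pairs an input shape with an output shape whose reordered dimension is rounded up to a multiple of the block size, matching the check in NEReorderKernel::validate. A self-checking sketch of that rounding (ceil_to_multiple restated here for illustration; the library's own helper lives in arm_compute):

constexpr int ceil_to_multiple(int value, int multiple)
{
    return ((value + multiple - 1) / multiple) * multiple;
}

static_assert(ceil_to_multiple(9, 4) == 12, "(10U, 9U) -> (10U, 12U) for OHWIo4");
static_assert(ceil_to_multiple(301, 4) == 304, "(234U, 301U) -> (234U, 304U) for OHWIo4");
static_assert(ceil_to_multiple(9, 8) == 16, "(10U, 9U) -> (10U, 16U) for OHWIo8");
static_assert(ceil_to_multiple(511, 8) == 512, "(10U, 511U) -> (10U, 512U) for OHWIo8");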

tests/validation/NEON/ReorderLayer.cpp

Lines changed: 2 additions & 2 deletions
@@ -80,7 +80,7 @@ DATA_TEST_CASE(ValidateReorderOHWIo8, framework::DatasetMode::ALL, combine(
    }
 }
 
-FIXTURE_DATA_TEST_CASE(RunBlock8, NEReorderLayerAlias<float>, framework::DatasetMode::ALL, combine(datasets::ReorderLayerDatasetBlock8(), make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunBlock8, NEReorderLayerAlias<float>, framework::DatasetMode::ALL, combine(datasets::ReorderLayerDatasetBlock8(), make("Transpose", {true, false}), make("DataType", DataType::F32)))
 {
    // Validate output
    if (_hardware_supports)

@@ -90,7 +90,7 @@ FIXTURE_DATA_TEST_CASE(RunBlock8, NEReorderLayerAlias<float>, framework::Dataset
    }
 }
 #endif // ARM_COMPUTE_ENABLE_SVE
 
-FIXTURE_DATA_TEST_CASE(RunBlock4, NEReorderLayerAlias<float>, framework::DatasetMode::ALL, combine(datasets::ReorderLayerDatasetBlock4(), make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunBlock4, NEReorderLayerAlias<float>, framework::DatasetMode::ALL, combine(datasets::ReorderLayerDatasetBlock4(), make("Transpose", {true, false}), make("DataType", DataType::F32)))
 {
    // Validate output
    validate(Accessor(_target), _reference);