Skip to content

Commit ce1a7f0

Browse files
committed
feat: Enable non-transposed BF16 reorders
- Add a new specialization of the C++ reorder template for BF16
- Enable non-transposed BF16 reorders

Resolves: [MLINFSW-1092]
Change-Id: Ib58cb479664b6e0579ed9297791d8c2bafb92a83
Signed-off-by: Ryo Suzuki <[email protected]>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/14537
Tested-by: Arm Jenkins <[email protected]>
Reviewed-by: Gunes Bayir <[email protected]>
Benchmark: Arm Jenkins <[email protected]>
Comments-Addressed: Arm Jenkins <[email protected]>
1 parent eb10c12 commit ce1a7f0

File tree

12 files changed

+431
-206
lines changed

12 files changed

+431
-206
lines changed

Android.bp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ cc_library_static {
350350
"src/core/NEON/kernels/arm_gemm/quantized.cpp",
351351
"src/core/NEON/kernels/arm_gemm/rowsum_indirect_s8.cpp",
352352
"src/core/NEON/kernels/arm_gemm/rowsum_indirect_u8.cpp",
353+
"src/core/NEON/kernels/arm_gemm/transform-bf16.cpp",
353354
"src/core/NEON/kernels/arm_gemm/transform-sve.cpp",
354355
"src/core/NEON/kernels/arm_gemm/transform.cpp",
355356
"src/core/NEON/kernels/batchnormalization/impl/NEON/fp16.cpp",

filelist.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,7 @@
16531653
"src/core/NEON/kernels/arm_gemm/rowsum_indirect_s8.cpp",
16541654
"src/core/NEON/kernels/arm_gemm/rowsum_indirect_u8.cpp",
16551655
"src/core/NEON/kernels/arm_gemm/transform.cpp",
1656+
"src/core/NEON/kernels/arm_gemm/transform-bf16.cpp",
16561657
"src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_8x12/generic.cpp",
16571658
"src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp",
16581659
"src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_8x12/a55r1.cpp",

src/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,7 @@ filegroup(
605605
"core/NEON/kernels/arm_gemm/quantized.cpp",
606606
"core/NEON/kernels/arm_gemm/rowsum_indirect_s8.cpp",
607607
"core/NEON/kernels/arm_gemm/rowsum_indirect_u8.cpp",
608+
"core/NEON/kernels/arm_gemm/transform-bf16.cpp",
608609
"core/NEON/kernels/arm_gemm/transform.cpp",
609610
"core/NEON/kernels/batchnormalization/impl/NEON/fp32.cpp",
610611
"core/NEON/kernels/convolution/common/padding.cpp",

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ target_sources(
598598
core/NEON/kernels/arm_gemm/quantized.cpp
599599
core/NEON/kernels/arm_gemm/rowsum_indirect_s8.cpp
600600
core/NEON/kernels/arm_gemm/rowsum_indirect_u8.cpp
601+
core/NEON/kernels/arm_gemm/transform-bf16.cpp
601602
core/NEON/kernels/arm_gemm/transform.cpp
602603
core/NEON/kernels/batchnormalization/impl/NEON/fp32.cpp
603604
core/NEON/kernels/convolution/common/padding.cpp

src/core/NEON/kernels/NEReorderKernel.cpp

Lines changed: 93 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,23 @@ std::map<TransformParams, void (*)(float *, const float *, int, int, int, int, i
6262
{{4, 1, true, arm_gemm::VLType::None}, &arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None, float, float>},
6363
{{4, 1, false, arm_gemm::VLType::None}, &arm_gemm::Transform<4, 1, false, arm_gemm::VLType::None, float, float>},
6464
{{8, 1, false, arm_gemm::VLType::None}, &arm_gemm::Transform<8, 1, false, arm_gemm::VLType::None, float, float>},
65+
{{8, 1, true, arm_gemm::VLType::None}, &arm_gemm::Transform<8, 1, true, arm_gemm::VLType::None, float, float>},
6566
#ifdef ARM_COMPUTE_ENABLE_SVE
6667
// When there is an asm kernel, use formula in transform.cpp to get the interleave_by_ number
6768
{{1, 1, true, arm_gemm::VLType::SVE}, &arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE, float, float>},
6869
#endif // ARM_COMPUTE_ENABLE_SVE
6970
};
7071

7172
std::map<TransformParams, void (*)(bfloat16 *, const float *, int, int, int, int, int)> supported_bf16_transforms = {
73+
#ifdef ARM_COMPUTE_ENABLE_BF16
7274
{{4, 4, true, arm_gemm::VLType::None}, &arm_gemm::Transform<4, 4, true, arm_gemm::VLType::None, bfloat16, float>},
75+
{{4, 4, false, arm_gemm::VLType::None}, &arm_gemm::Transform<4, 4, false, arm_gemm::VLType::None, bfloat16, float>},
76+
{{8, 4, false, arm_gemm::VLType::None}, &arm_gemm::Transform<8, 4, false, arm_gemm::VLType::None, bfloat16, float>},
77+
{{8, 4, true, arm_gemm::VLType::None}, &arm_gemm::Transform<8, 4, true, arm_gemm::VLType::None, bfloat16, float>},
7378
#ifdef ARM_COMPUTE_ENABLE_SVE
7479
{{2, 4, true, arm_gemm::VLType::SVE}, &arm_gemm::Transform<2, 4, true, arm_gemm::VLType::SVE, bfloat16, float>},
7580
#endif // ARM_COMPUTE_ENABLE_SVE
81+
#endif // ARM_COMPUTE_ENABLE_BF16
7682
};
7783

7884
#ifdef ARM_COMPUTE_ENABLE_SVE
@@ -133,23 +139,28 @@ void NEReorderKernel::run(const Window &window, const ThreadInfo &info)
133139
}
134140
case DataType::BFLOAT16:
135141
{
136-
void (*transform_func)(bfloat16 *, const float *, int, int, int, int, int) = nullptr;
137-
#ifdef ARM_COMPUTE_ENABLE_SVE
138-
if (CPUInfo::get().has_sve())
142+
if (CPUInfo::get().has_bf16())
139143
{
140-
TransformParams tparams = {get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
141-
_transpose, arm_gemm::VLType::SVE};
142-
if (supported_bf16_transforms.count(tparams))
143-
transform_func = supported_bf16_transforms[tparams];
144-
}
144+
void (*transform_func)(bfloat16 *, const float *, int, int, int, int, int) = nullptr;
145+
#ifdef ARM_COMPUTE_ENABLE_SVE
146+
if (CPUInfo::get().has_sve())
147+
{
148+
TransformParams tparams = {get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
149+
_transpose, arm_gemm::VLType::SVE};
150+
if (supported_bf16_transforms.count(tparams))
151+
transform_func = supported_bf16_transforms[tparams];
152+
}
145153
#endif // ARM_COMPUTE_ENABLE_SVE
146-
if (transform_func == nullptr)
147-
{
148-
transform_func =
149-
supported_bf16_transforms[{interleave_by, block_by, _transpose, arm_gemm::VLType::None}];
154+
if (transform_func == nullptr)
155+
{
156+
transform_func =
157+
supported_bf16_transforms[{interleave_by, block_by, _transpose, arm_gemm::VLType::None}];
158+
}
159+
transform_func(reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
160+
reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
161+
break;
150162
}
151-
transform_func(reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
152-
reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
163+
ARM_COMPUTE_ERROR("Trying to run BF16 on unsupported machine\n");
153164
break;
154165
}
155166
default:
@@ -236,84 +247,85 @@ Status NEReorderKernel::validate(const ITensorInfo *input,
236247
ARM_COMPUTE_UNUSED(input_wf);
237248
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
238249
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
239-
if (output->tensor_shape().total_size() != 0)
240-
{
241-
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
242-
ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != DataType::F32 && output->data_type() != DataType::BFLOAT16);
243-
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
250+
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
251+
ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != DataType::F32 && output->data_type() != DataType::BFLOAT16);
252+
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
244253

245-
int input_x_dim;
246-
int input_k_dim;
247-
int output_x_dim;
248-
int output_k_dim;
249-
auto dims = output->num_dimensions();
250-
switch (dims)
254+
int input_x_dim;
255+
int input_k_dim;
256+
int output_x_dim;
257+
int output_k_dim;
258+
auto dims = output->num_dimensions();
259+
switch (dims)
260+
{
261+
case 2:
251262
{
252-
case 2:
253-
{
254-
input_x_dim = input->dimension(0); // Number of columns in input matrix
255-
input_k_dim = input->dimension(1); // Number of rows in input matrix
256-
output_x_dim = output->dimension(0); // Number of columns in output matrix
257-
output_k_dim = output->dimension(1); // Number of rows in output matrix
258-
break;
259-
}
260-
case 4:
261-
{
262-
input_x_dim = input->dimension(2); // Number of columns in input matrix
263-
input_k_dim = input->dimension(3); // Number of rows in input matrix
264-
output_x_dim = output->dimension(2); // Number of columns in output matrix
265-
output_k_dim = output->dimension(3); // Number of rows in output matrix
266-
break;
267-
}
268-
default:
269-
{
270-
ARM_COMPUTE_RETURN_ERROR_MSG("Only 2 or 4 dimensions supported.");
271-
}
263+
input_x_dim = input->dimension(0); // Number of columns in input matrix
264+
input_k_dim = input->dimension(1); // Number of rows in input matrix
265+
output_x_dim = output->dimension(0); // Number of columns in output matrix
266+
output_k_dim = output->dimension(1); // Number of rows in output matrix
267+
break;
268+
}
269+
case 4:
270+
{
271+
input_x_dim = input->dimension(2); // Number of columns in input matrix
272+
input_k_dim = input->dimension(3); // Number of rows in input matrix
273+
output_x_dim = output->dimension(2); // Number of columns in output matrix
274+
output_k_dim = output->dimension(3); // Number of rows in output matrix
275+
break;
272276
}
277+
default:
278+
{
279+
ARM_COMPUTE_RETURN_ERROR_MSG("Only 2 or 4 dimensions supported.");
280+
}
281+
}
273282

274-
int ksize = 0;
275-
int interleave_by = arm_compute::interleave_by(output_wf);
276-
int block_by = arm_compute::block_by(output_wf);
277-
ARM_COMPUTE_RETURN_ERROR_ON(interleave_by != 4 && interleave_by != 8);
278-
ksize = interleave_by;
283+
int ksize = 0;
284+
int interleave_by = arm_compute::interleave_by(output_wf);
285+
int block_by = arm_compute::block_by(output_wf);
286+
ARM_COMPUTE_RETURN_ERROR_ON(interleave_by != 4 && interleave_by != 8);
287+
ksize = interleave_by;
279288

280-
// output k_dim needs to be same as input but multiple of ksize
281-
int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_k_dim, ksize);
282-
ARM_COMPUTE_RETURN_ERROR_ON(rnd_up_input_kdim != output_k_dim);
283-
// output x_dim needs to be same as input
284-
ARM_COMPUTE_RETURN_ERROR_ON(input_x_dim != output_x_dim);
289+
// output x_dim needs to be same as input but multiple of block_by
290+
int32_t rnd_up_input_xdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_x_dim, block_by);
291+
ARM_COMPUTE_RETURN_ERROR_ON(rnd_up_input_xdim != output_x_dim);
292+
// output k_dim needs to be same as input but multiple of ksize
293+
int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_k_dim, ksize);
294+
ARM_COMPUTE_RETURN_ERROR_ON(rnd_up_input_kdim != output_k_dim);
295+
// output x_dim needs to be same as input
296+
ARM_COMPUTE_RETURN_ERROR_ON(input_x_dim != output_x_dim);
285297

286-
switch (output->data_type())
298+
switch (output->data_type())
299+
{
300+
case DataType::F32:
287301
{
288-
case DataType::F32:
289-
{
290302
#ifdef ARM_COMPUTE_ENABLE_SVE
291-
if (CPUInfo::get().has_sve() &&
292-
supported_float_transforms.count({get_sve_interleave_by<float>(interleave_by, block_by), block_by,
293-
transpose, arm_gemm::VLType::SVE}))
294-
break;
295-
#endif // ARM_COMPUTE_ENABLE_SVE
296-
ARM_COMPUTE_RETURN_ERROR_ON(
297-
!supported_float_transforms.count({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
303+
if (CPUInfo::get().has_sve() &&
304+
supported_float_transforms.count({get_sve_interleave_by<float>(interleave_by, block_by), block_by,
305+
transpose, arm_gemm::VLType::SVE}))
298306
break;
299-
}
300-
case DataType::BFLOAT16:
301-
{
302-
#ifdef ARM_COMPUTE_ENABLE_SVE
303-
if (CPUInfo::get().has_sve() &&
304-
supported_bf16_transforms.count({get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
305-
transpose, arm_gemm::VLType::SVE}))
306-
break;
307307
#endif // ARM_COMPUTE_ENABLE_SVE
308-
ARM_COMPUTE_RETURN_ERROR_ON(
309-
!supported_bf16_transforms.count({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
310-
break;
311-
}
312-
default:
313-
{
314-
ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported output data type");
308+
ARM_COMPUTE_RETURN_ERROR_ON(
309+
!supported_float_transforms.count({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
310+
break;
311+
}
312+
case DataType::BFLOAT16:
313+
{
314+
ARM_COMPUTE_ERROR_ON(!CPUInfo::get().has_bf16());
315+
#ifdef ARM_COMPUTE_ENABLE_SVE
316+
if (CPUInfo::get().has_sve() &&
317+
supported_bf16_transforms.count({get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
318+
transpose, arm_gemm::VLType::SVE}))
315319
break;
316-
}
320+
#endif // ARM_COMPUTE_ENABLE_SVE
321+
ARM_COMPUTE_RETURN_ERROR_ON(
322+
!supported_bf16_transforms.count({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
323+
break;
324+
}
325+
default:
326+
{
327+
ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported output data type");
328+
break;
317329
}
318330
}
319331
return Status{};

0 commit comments

Comments
 (0)