@@ -62,17 +62,23 @@ std::map<TransformParams, void (*)(float *, const float *, int, int, int, int, i
6262 {{4 , 1 , true , arm_gemm::VLType::None}, &arm_gemm::Transform<4 , 1 , true , arm_gemm::VLType::None, float , float >},
6363 {{4 , 1 , false , arm_gemm::VLType::None}, &arm_gemm::Transform<4 , 1 , false , arm_gemm::VLType::None, float , float >},
6464 {{8 , 1 , false , arm_gemm::VLType::None}, &arm_gemm::Transform<8 , 1 , false , arm_gemm::VLType::None, float , float >},
65+ {{8 , 1 , true , arm_gemm::VLType::None}, &arm_gemm::Transform<8 , 1 , true , arm_gemm::VLType::None, float , float >},
6566#ifdef ARM_COMPUTE_ENABLE_SVE
6667 // When there is an asm kernel, use formula in transform.cpp to get the interleave_by_ number
6768 {{1 , 1 , true , arm_gemm::VLType::SVE}, &arm_gemm::Transform<1 , 1 , true , arm_gemm::VLType::SVE, float , float >},
6869#endif // ARM_COMPUTE_ENABLE_SVE
6970};
7071
7172std::map<TransformParams, void (*)(bfloat16 *, const float *, int , int , int , int , int )> supported_bf16_transforms = {
73+ #ifdef ARM_COMPUTE_ENABLE_BF16
7274 {{4 , 4 , true , arm_gemm::VLType::None}, &arm_gemm::Transform<4 , 4 , true , arm_gemm::VLType::None, bfloat16, float >},
75+ {{4 , 4 , false , arm_gemm::VLType::None}, &arm_gemm::Transform<4 , 4 , false , arm_gemm::VLType::None, bfloat16, float >},
76+ {{8 , 4 , false , arm_gemm::VLType::None}, &arm_gemm::Transform<8 , 4 , false , arm_gemm::VLType::None, bfloat16, float >},
77+ {{8 , 4 , true , arm_gemm::VLType::None}, &arm_gemm::Transform<8 , 4 , true , arm_gemm::VLType::None, bfloat16, float >},
7378#ifdef ARM_COMPUTE_ENABLE_SVE
7479 {{2 , 4 , true , arm_gemm::VLType::SVE}, &arm_gemm::Transform<2 , 4 , true , arm_gemm::VLType::SVE, bfloat16, float >},
7580#endif // ARM_COMPUTE_ENABLE_SVE
81+ #endif // ARM_COMPUTE_ENABLE_BF16
7682};
7783
7884#ifdef ARM_COMPUTE_ENABLE_SVE
@@ -133,23 +139,28 @@ void NEReorderKernel::run(const Window &window, const ThreadInfo &info)
133139 }
134140 case DataType::BFLOAT16:
135141 {
136- void (*transform_func)(bfloat16 *, const float *, int , int , int , int , int ) = nullptr ;
137- #ifdef ARM_COMPUTE_ENABLE_SVE
138- if (CPUInfo::get ().has_sve ())
142+ if (CPUInfo::get ().has_bf16 ())
139143 {
140- TransformParams tparams = {get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
141- _transpose, arm_gemm::VLType::SVE};
142- if (supported_bf16_transforms.count (tparams))
143- transform_func = supported_bf16_transforms[tparams];
144- }
144+ void (*transform_func)(bfloat16 *, const float *, int , int , int , int , int ) = nullptr ;
145+ #ifdef ARM_COMPUTE_ENABLE_SVE
146+ if (CPUInfo::get ().has_sve ())
147+ {
148+ TransformParams tparams = {get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
149+ _transpose, arm_gemm::VLType::SVE};
150+ if (supported_bf16_transforms.count (tparams))
151+ transform_func = supported_bf16_transforms[tparams];
152+ }
145153#endif // ARM_COMPUTE_ENABLE_SVE
146- if (transform_func == nullptr )
147- {
148- transform_func =
149- supported_bf16_transforms[{interleave_by, block_by, _transpose, arm_gemm::VLType::None}];
154+ if (transform_func == nullptr )
155+ {
156+ transform_func =
157+ supported_bf16_transforms[{interleave_by, block_by, _transpose, arm_gemm::VLType::None}];
158+ }
159+ transform_func (reinterpret_cast <bfloat16 *>(_output->buffer ()) + jump_rows,
160+ reinterpret_cast <float *>(_input->buffer ()), stride, k_start, k_end, 0 , _xmax);
161+ break ;
150162 }
151- transform_func (reinterpret_cast <bfloat16 *>(_output->buffer ()) + jump_rows,
152- reinterpret_cast <float *>(_input->buffer ()), stride, k_start, k_end, 0 , _xmax);
163+ ARM_COMPUTE_ERROR (" Trying to run BF16 on unsupported machine\n " );
153164 break ;
154165 }
155166 default :
@@ -236,84 +247,85 @@ Status NEReorderKernel::validate(const ITensorInfo *input,
236247 ARM_COMPUTE_UNUSED (input_wf);
237248 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR (input, output);
238249 ARM_COMPUTE_RETURN_ERROR_ON (input->data_type () == DataType::UNKNOWN);
239- if (output->tensor_shape ().total_size () != 0 )
240- {
241- ARM_COMPUTE_RETURN_ERROR_ON (input->data_type () != DataType::F32);
242- ARM_COMPUTE_RETURN_ERROR_ON (output->data_type () != DataType::F32 && output->data_type () != DataType::BFLOAT16);
243- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO (input, output);
250+ ARM_COMPUTE_RETURN_ERROR_ON (input->data_type () != DataType::F32);
251+ ARM_COMPUTE_RETURN_ERROR_ON (output->data_type () != DataType::F32 && output->data_type () != DataType::BFLOAT16);
252+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO (input, output);
244253
245- int input_x_dim;
246- int input_k_dim;
247- int output_x_dim;
248- int output_k_dim;
249- auto dims = output->num_dimensions ();
250- switch (dims)
254+ int input_x_dim;
255+ int input_k_dim;
256+ int output_x_dim;
257+ int output_k_dim;
258+ auto dims = output->num_dimensions ();
259+ switch (dims)
260+ {
261+ case 2 :
251262 {
252- case 2 :
253- {
254- input_x_dim = input->dimension (0 ); // Number of columns in input matrix
255- input_k_dim = input->dimension (1 ); // Number of rows in input matrix
256- output_x_dim = output->dimension (0 ); // Number of columns in output matrix
257- output_k_dim = output->dimension (1 ); // Number of rows in output matrix
258- break ;
259- }
260- case 4 :
261- {
262- input_x_dim = input->dimension (2 ); // Number of columns in input matrix
263- input_k_dim = input->dimension (3 ); // Number of rows in input matrix
264- output_x_dim = output->dimension (2 ); // Number of columns in output matrix
265- output_k_dim = output->dimension (3 ); // Number of rows in output matrix
266- break ;
267- }
268- default :
269- {
270- ARM_COMPUTE_RETURN_ERROR_MSG (" Only 2 or 4 dimensions supported." );
271- }
263+ input_x_dim = input->dimension (0 ); // Number of columns in input matrix
264+ input_k_dim = input->dimension (1 ); // Number of rows in input matrix
265+ output_x_dim = output->dimension (0 ); // Number of columns in output matrix
266+ output_k_dim = output->dimension (1 ); // Number of rows in output matrix
267+ break ;
268+ }
269+ case 4 :
270+ {
271+ input_x_dim = input->dimension (2 ); // Number of columns in input matrix
272+ input_k_dim = input->dimension (3 ); // Number of rows in input matrix
273+ output_x_dim = output->dimension (2 ); // Number of columns in output matrix
274+ output_k_dim = output->dimension (3 ); // Number of rows in output matrix
275+ break ;
272276 }
277+ default :
278+ {
279+ ARM_COMPUTE_RETURN_ERROR_MSG (" Only 2 or 4 dimensions supported." );
280+ }
281+ }
273282
274- int ksize = 0 ;
275- int interleave_by = arm_compute::interleave_by (output_wf);
276- int block_by = arm_compute::block_by (output_wf);
277- ARM_COMPUTE_RETURN_ERROR_ON (interleave_by != 4 && interleave_by != 8 );
278- ksize = interleave_by;
283+ int ksize = 0 ;
284+ int interleave_by = arm_compute::interleave_by (output_wf);
285+ int block_by = arm_compute::block_by (output_wf);
286+ ARM_COMPUTE_RETURN_ERROR_ON (interleave_by != 4 && interleave_by != 8 );
287+ ksize = interleave_by;
279288
280- // output k_dim needs to be same as input but multiple of ksize
281- int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t , int32_t >(input_k_dim, ksize);
282- ARM_COMPUTE_RETURN_ERROR_ON (rnd_up_input_kdim != output_k_dim);
283- // output x_dim needs to be same as input
284- ARM_COMPUTE_RETURN_ERROR_ON (input_x_dim != output_x_dim);
289+ // output x_dim needs to be same as input but multiple of block_by
290+ int32_t rnd_up_input_xdim = arm_compute::ceil_to_multiple<int32_t , int32_t >(input_x_dim, block_by);
291+ ARM_COMPUTE_RETURN_ERROR_ON (rnd_up_input_xdim != output_x_dim);
292+ // output k_dim needs to be same as input but multiple of ksize
293+ int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t , int32_t >(input_k_dim, ksize);
294+ ARM_COMPUTE_RETURN_ERROR_ON (rnd_up_input_kdim != output_k_dim);
295+ // output x_dim needs to be same as input
296+ ARM_COMPUTE_RETURN_ERROR_ON (input_x_dim != output_x_dim);
285297
286- switch (output->data_type ())
298+ switch (output->data_type ())
299+ {
300+ case DataType::F32:
287301 {
288- case DataType::F32:
289- {
290302#ifdef ARM_COMPUTE_ENABLE_SVE
291- if (CPUInfo::get ().has_sve () &&
292- supported_float_transforms.count ({get_sve_interleave_by<float >(interleave_by, block_by), block_by,
293- transpose, arm_gemm::VLType::SVE}))
294- break ;
295- #endif // ARM_COMPUTE_ENABLE_SVE
296- ARM_COMPUTE_RETURN_ERROR_ON (
297- !supported_float_transforms.count ({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
303+ if (CPUInfo::get ().has_sve () &&
304+ supported_float_transforms.count ({get_sve_interleave_by<float >(interleave_by, block_by), block_by,
305+ transpose, arm_gemm::VLType::SVE}))
298306 break ;
299- }
300- case DataType::BFLOAT16:
301- {
302- #ifdef ARM_COMPUTE_ENABLE_SVE
303- if (CPUInfo::get ().has_sve () &&
304- supported_bf16_transforms.count ({get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
305- transpose, arm_gemm::VLType::SVE}))
306- break ;
307307#endif // ARM_COMPUTE_ENABLE_SVE
308- ARM_COMPUTE_RETURN_ERROR_ON (
309- !supported_bf16_transforms.count ({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
310- break ;
311- }
312- default :
313- {
314- ARM_COMPUTE_RETURN_ERROR_MSG (" Unsupported output data type" );
308+ ARM_COMPUTE_RETURN_ERROR_ON (
309+ !supported_float_transforms.count ({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
310+ break ;
311+ }
312+ case DataType::BFLOAT16:
313+ {
314+ ARM_COMPUTE_ERROR_ON (!CPUInfo::get ().has_bf16 ());
315+ #ifdef ARM_COMPUTE_ENABLE_SVE
316+ if (CPUInfo::get ().has_sve () &&
317+ supported_bf16_transforms.count ({get_sve_interleave_by<bfloat16>(interleave_by, block_by), block_by,
318+ transpose, arm_gemm::VLType::SVE}))
315319 break ;
316- }
320+ #endif // ARM_COMPUTE_ENABLE_SVE
321+ ARM_COMPUTE_RETURN_ERROR_ON (
322+ !supported_bf16_transforms.count ({interleave_by, block_by, transpose, arm_gemm::VLType::None}));
323+ break ;
324+ }
325+ default :
326+ {
327+ ARM_COMPUTE_RETURN_ERROR_MSG (" Unsupported output data type" );
328+ break ;
317329 }
318330 }
319331 return Status{};
0 commit comments