@@ -70,7 +70,8 @@ class CpuFullyConnectedValidationGenericFixture : public framework::Fixture
                DataType data_type,
                QuantizationInfo quantization_info,
                ActivationLayerInfo activation_info,
-               TestType test_type)
+               TestType test_type,
+               bool with_bias = true)
     {
         if (std::is_same<TensorType, Tensor>::value && // Cpu
             data_type == DataType::F16 && !CPUInfo::get().has_fp16())
@@ -92,8 +93,8 @@ class CpuFullyConnectedValidationGenericFixture : public framework::Fixture
 
         _activation_info = activation_info;
 
-        compute_target(input_shape, weights_shape, bias_shape, output_shape);
-        compute_reference(input_shape, weights_shape, bias_shape, output_shape);
+        compute_target(input_shape, weights_shape, bias_shape, output_shape, with_bias);
+        compute_reference(input_shape, weights_shape, bias_shape, output_shape, with_bias);
     }
 
 protected:
@@ -147,7 +148,8 @@ class CpuFullyConnectedValidationGenericFixture : public framework::Fixture
     void compute_target(const TensorShape &input_shape,
                         const TensorShape &weights_shape,
                         const TensorShape &bias_shape,
-                        const TensorShape &output_shape)
+                        const TensorShape &output_shape,
+                        bool with_bias)
     {
         TensorShape reshaped_weights_shape(weights_shape);
 
@@ -181,15 +183,16 @@ class CpuFullyConnectedValidationGenericFixture : public framework::Fixture
         {
             src[i] = create_tensor<TensorType>(input_shape, _data_type, 1, _input_q_info);
             weights[i] = create_tensor<TensorType>(reshaped_weights_shape, _data_type, 1, _weight_q_info);
-            bias[i] = create_tensor<TensorType>(bias_shape, _bias_data_type, 1);
+            bias[i] = with_bias ? create_tensor<TensorType>(bias_shape, _bias_data_type, 1) : TensorType{};
             dst[i] = create_tensor<TensorType>(output_shape, _data_type, 1, _dst_q_info);
-            weights[i].info()->set_are_values_constant(false);
+            weights[i].info()->set_are_values_constant(false);
         }
         tmp_weights = create_tensor<TensorType>(weights_shape, _data_type, 1, _weight_q_info);
         tmp_weights.allocator()->allocate();
 
         const bool kernel_found =
-            bool(FunctionType::has_opt_impl(computed_weight_format, src[0].info(), weights[0].info(), bias[0].info(),
+            bool(FunctionType::has_opt_impl(computed_weight_format, src[0].info(), weights[0].info(),
+                                            with_bias ? bias[0].info() : nullptr,
                                             dst[0].info(), fc_info, wei_info));
         ARM_COMPUTE_ASSERT(kernel_found);
         wei_info.set_weight_format(computed_weight_format);
@@ -201,36 +204,47 @@ class CpuFullyConnectedValidationGenericFixture : public framework::Fixture
             reordered_weights[i].info()->set_is_resizable(true);
         }
 
-        // Create and configure function.
+        // Create, configure and validate function.
         FunctionType fc;
-        fc.configure(src[0].info(), weights[0].info(), bias[0].info(), dst[0].info(), fc_info, wei_info);
+        fc.configure(src[0].info(), weights[0].info(),
+                     with_bias ? bias[0].info() : nullptr,
+                     dst[0].info(), fc_info, wei_info);
         auto const aux_mem_req = fc.workspace();
 
+        ARM_COMPUTE_ASSERT(fc.validate(src[0].info(), weights[0].info(),
+                                       with_bias ? bias[0].info() : nullptr,
+                                       dst[0].info(), fc_info, wei_info));
+
         for (int i = 0; i < _num_parallel_runs; ++i)
         {
             ARM_COMPUTE_ASSERT(src[i].info()->is_resizable());
             ARM_COMPUTE_ASSERT(weights[i].info()->is_resizable());
             ARM_COMPUTE_ASSERT(reordered_weights[i].info()->is_resizable());
-            ARM_COMPUTE_ASSERT(bias[i].info()->is_resizable());
             ARM_COMPUTE_ASSERT(dst[i].info()->is_resizable());
 
             // Allocate tensors
             src[i].allocator()->allocate();
             weights[i].allocator()->allocate();
             reordered_weights[i].allocator()->allocate();
-            bias[i].allocator()->allocate();
             dst[i].allocator()->allocate();
 
             ARM_COMPUTE_ASSERT(!src[i].info()->is_resizable());
             ARM_COMPUTE_ASSERT(!weights[i].info()->is_resizable());
             ARM_COMPUTE_ASSERT(!reordered_weights[i].info()->is_resizable());
-            ARM_COMPUTE_ASSERT(!bias[i].info()->is_resizable());
             ARM_COMPUTE_ASSERT(!dst[i].info()->is_resizable());
 
             // Fill tensors
             fill(AccessorType(src[i]), 0 + i * 3);
             fill(AccessorType(tmp_weights), 1 + i * 3);
-            fill(AccessorType(bias[i]), 2 + i * 3);
+
+            // Handle optional bias
+            if (with_bias)
+            {
+                ARM_COMPUTE_ASSERT(bias[i].info()->is_resizable());
+                bias[i].allocator()->allocate();
+                ARM_COMPUTE_ASSERT(!bias[i].info()->is_resizable());
+                fill(AccessorType(bias[i]), 2 + i * 3);
+            }
 
             // Reorder weight to the expected format
             ARM_COMPUTE_ASSERT(reorder.validate(tmp_weights.info(), reordered_weights[i].info(), WeightFormat::OHWI,
@@ -240,9 +254,9 @@ class CpuFullyConnectedValidationGenericFixture : public framework::Fixture
         }
 
         // Prepare function.
-        prep_pack[0].add_const_tensor(arm_compute::TensorType::ACL_SRC_1, &reordered_weights[0]);
-        prep_pack[0].add_const_tensor(arm_compute::TensorType::ACL_SRC_2, &bias[0]);
-        fc.prepare(prep_pack[0]);
+        prep_pack[0].add_const_tensor(arm_compute::TensorType::ACL_SRC_1, &reordered_weights[0]);
+        prep_pack[0].add_const_tensor(arm_compute::TensorType::ACL_SRC_2, &bias[0]);
+        fc.prepare(prep_pack[0]);
 
         if (_test_type == TestType::ConfigureOnceRunMultiThreaded)
         {
@@ -295,20 +309,26 @@ class CpuFullyConnectedValidationGenericFixture : public framework::Fixture
     void compute_reference(const TensorShape &input_shape,
                            const TensorShape &weights_shape,
                            const TensorShape &bias_shape,
-                           const TensorShape &output_shape)
+                           const TensorShape &output_shape,
+                           bool with_bias)
     {
         // Create reference
         SimpleTensor<T> ref_src{input_shape, _data_type, 1, _input_q_info};
         SimpleTensor<T> ref_weights{weights_shape, _data_type, 1, _weight_q_info};
         SimpleTensor<TBias> ref_bias{bias_shape, _bias_data_type, 1, QuantizationInfo()};
+
         for (int i = 0; i < _num_parallel_runs; ++i)
         {
             // Fill reference
             fill(ref_src, 0 + i * 3);
             fill(ref_weights, 1 + i * 3);
-            fill(ref_bias, 2 + i * 3);
 
-            _reference[i] = reference::activation_layer(reference::fully_connected_layer<T>(ref_src, ref_weights, ref_bias, output_shape, _dst_q_info), _activation_info, _dst_q_info);
+            if (with_bias)
+            {
+                fill(ref_bias, 2 + i * 3);
+            }
+
+            _reference[i] = reference::activation_layer(reference::fully_connected_layer<T>(ref_src, ref_weights, ref_bias, output_shape, _dst_q_info, with_bias), _activation_info, _dst_q_info);
         }
     }
 
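The hunk above forwards the new with_bias flag into reference::fully_connected_layer so the reference result can be computed without a bias term. As a rough, self-contained sketch of what that flag gates (plain C++; fully_connected_ref and its parameters are illustrative stand-ins, not the actual ACL reference implementation, which is templated over tensor and bias types and quantization-aware):

#include <cstddef>
#include <vector>

// Illustrative only: the with_bias flag decides whether each output
// accumulator starts from the corresponding bias value or from zero.
std::vector<float> fully_connected_ref(const std::vector<float> &src,     // [batches x num_inputs]
                                       const std::vector<float> &weights, // [num_outputs x num_inputs]
                                       const std::vector<float> &bias,    // [num_outputs], unused if !with_bias
                                       std::size_t num_inputs,
                                       std::size_t num_outputs,
                                       std::size_t batches,
                                       bool with_bias)
{
    std::vector<float> dst(batches * num_outputs, 0.0f);
    for (std::size_t b = 0; b < batches; ++b)
    {
        for (std::size_t o = 0; o < num_outputs; ++o)
        {
            // Bias accumulation is gated by the flag.
            float acc = with_bias ? bias[o] : 0.0f;
            for (std::size_t i = 0; i < num_inputs; ++i)
            {
                acc += src[b * num_inputs + i] * weights[o * num_inputs + i];
            }
            dst[b * num_outputs + o] = acc;
        }
    }
    return dst;
}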
@@ -344,6 +364,24 @@ class CpuFullyConnectedValidationFixture
     }
 };
 
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class CpuFullyConnectedValidationFixtureNoBias
+    : public CpuFullyConnectedValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+    void setup(TensorShape input_shape,
+               TensorShape weights_shape,
+               TensorShape bias_shape,
+               TensorShape output_shape,
+               DataType data_type,
+               ActivationLayerInfo activation_info)
+    {
+        CpuFullyConnectedValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            input_shape, weights_shape, bias_shape, output_shape, data_type,
+            QuantizationInfo(), activation_info, TestType::ConfigureOnceRunOnce, false);
+    }
+};
+
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
 class CpuFullyConnectedThreadSafeValidationFixture
     : public CpuFullyConnectedValidationGenericFixture<TensorType, AccessorType, FunctionType, T>