@@ -243,29 +243,17 @@ namespace dlib { namespace tt
243243 else if (g_mode == PLANE_WISE)
244244 {
245245 auto is_matrix = [](const auto & tensor) {
246- return (tensor.num_samples () == 1 && tensor.k () == 1 ) ||
247- (tensor.nr () == 1 && tensor.nc () == 1 );
246+ return (( tensor.num_samples () * tensor. k () == 1 && tensor.nr () * tensor. nc () > 1 ) ||
247+ (tensor.num_samples () * tensor. k () > 1 && tensor.nr () * tensor. nc () == 1 ) );
248248 };
249249
250- long num_samples = std::max ({ lhs.num_samples (), rhs.num_samples (), dest.num_samples () });
251- long num_channels = std::max ({ lhs.k (), rhs.k (), dest.k () });
250+ long num_samples = std::min ({ lhs.num_samples (), rhs.num_samples (), dest.num_samples () });
251+ long num_channels = std::min ({ lhs.k (), rhs.k (), dest.k () });
252252 const bool lhs_is_matrix = is_matrix (lhs), rhs_is_matrix = is_matrix (rhs), dest_is_matrix = is_matrix (dest);
253253
254254 if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) {
255255 num_samples = num_channels = 1 ;
256256 }
257- else
258- {
259- auto adjust = [&](const auto & tensor) {
260- if (!is_matrix (tensor)) {
261- if (tensor.num_samples () < num_samples) num_samples = tensor.num_samples ();
262- if (tensor.k () < num_channels) num_channels = tensor.k ();
263- }
264- };
265- adjust (lhs);
266- adjust (rhs);
267- adjust (dest);
268- }
269257
270258 long lhs_rows = (lhs_is_matrix && lhs.num_samples () > 1 ) ? lhs.num_samples () : lhs.nr ();
271259 long lhs_cols = (lhs_is_matrix && lhs.k () > 1 ) ? lhs.k () : lhs.nc ();
0 commit comments