Skip to content

Commit c81efb7

Browse files
authored
Same improvements for the CPU version
1 parent d173fbd commit c81efb7

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

dlib/cuda/tensor_tools.cpp

Lines changed: 4 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -243,29 +243,17 @@ namespace dlib { namespace tt
243 243
else if (g_mode == PLANE_WISE)
244 244
{
245 245
auto is_matrix = [](const auto& tensor) {
246-
return (tensor.num_samples() == 1 && tensor.k() == 1) ||
247-
(tensor.nr() == 1 && tensor.nc() == 1);
246+
return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) ||
247+
(tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1));
248 248
};
249 249

250-
long num_samples = std::max({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
251-
long num_channels = std::max({ lhs.k(), rhs.k(), dest.k() });
250+
long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
251+
long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });
252 252
const bool lhs_is_matrix = is_matrix(lhs), rhs_is_matrix = is_matrix(rhs), dest_is_matrix = is_matrix(dest);
253 253

254 254
if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) {
255 255
num_samples = num_channels = 1;
256 256
}
257-
else
258-
{
259-
auto adjust = [&](const auto& tensor) {
260-
if (!is_matrix(tensor)) {
261-
if (tensor.num_samples() < num_samples) num_samples = tensor.num_samples();
262-
if (tensor.k() < num_channels) num_channels = tensor.k();
263-
}
264-
};
265-
adjust(lhs);
266-
adjust(rhs);
267-
adjust(dest);
268-
}
269 257

270 258
long lhs_rows = (lhs_is_matrix && lhs.num_samples() > 1) ? lhs.num_samples() : lhs.nr();
271 259
long lhs_cols = (lhs_is_matrix && lhs.k() > 1) ? lhs.k() : lhs.nc();

0 commit comments

Comments
 (0)