
Commit 9cb9ba0

SS-JIA authored and facebook-github-bot committed
Log mismatched values during tensors_have check functions
Summary: D48165168 added `bool tensors_have_same_.*(...)` functions as an alternative to the `ET_CHECK_SAME_.*(...)` macros. However, one key difference between the two is that the original macros report which values caused the check to fail, whereas the new functions do not. This diff brings the new functions to behavioural parity with the original macros by logging which values caused the check to return `false`.

Reviewed By: manuelcandales

Differential Revision: D48247982

fbshipit-source-id: 2f4f0ec59428d26949e44f1b04ea2ecd8d824b29
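To make the difference concrete, here is a rough sketch (not code from this commit; the caller and its error handling are hypothetical) contrasting the two styles:

    // Old style: one of the ET_CHECK_SAME_.* macros. Aborts on mismatch,
    // but prints the offending dtypes as part of the failure message.
    ET_CHECK_SAME_DTYPE2(a, b);

    // New style: returns a bool so the caller can fail gracefully. Before
    // this diff a mismatch was reported with no detail; after it, the two
    // dtypes are logged via ET_LOG(Error, ...) before returning false.
    if (!torch::executor::tensors_have_same_dtype(a, b)) {
      return; // hypothetical graceful-failure path
    }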
1 parent 48bb99f · commit 9cb9ba0

File tree

3 files changed: +280 −20 lines changed

- kernels/portable/cpu/scalar_utils.h
- runtime/core/exec_aten/util/tensor_util.h
- runtime/core/exec_aten/util/test/tensor_util_test.cpp

kernels/portable/cpu/scalar_utils.h

Lines changed: 14 additions & 0 deletions
@@ -65,6 +65,20 @@ inline ScalarType get_scalar_dtype(Scalar scalar) {
   ET_CHECK_MSG(false, "Scalar must be Boolean, Integral or Floating.");
 }
 
+inline bool scalars_have_same_dtype(Scalar a, Scalar b) {
+  ScalarType a_dtype = get_scalar_dtype(a);
+  ScalarType b_dtype = get_scalar_dtype(b);
+  if (a_dtype == b_dtype) {
+    return true;
+  }
+  ET_LOG(
+      Error,
+      "Expected scalars to have the same dtype, but found %s and %s",
+      toString(a_dtype),
+      toString(b_dtype));
+  return false;
+}
+
 /**
  * Implement type promotion between a tensor's ScalarType with a Scalar.
  * If the Scalar contains a value in the same category of the tensor's
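A usage sketch for the new helper (the surrounding kernel and its Scalar arguments are hypothetical, and namespace qualification is omitted on the assumption that the caller sits alongside the other scalar_utils.h helpers):

    // On a mismatch, scalars_have_same_dtype() now logs both dtypes, e.g.
    // "Expected scalars to have the same dtype, but found Int and Float",
    // and returns false instead of aborting the way ET_CHECK_MSG would.
    if (!scalars_have_same_dtype(alpha, beta)) {
      return false; // propagate the failure to the caller
    }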

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 223 additions & 20 deletions
@@ -411,40 +411,199 @@ using ScalarType = exec_aten::ScalarType;
 // Utility functions for checking tensor attributes
 //
 
+inline bool tensor_can_cast_to(
+    exec_aten::Tensor a,
+    exec_aten::ScalarType dtype) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::canCast(a.scalar_type(), dtype),
+      "Tensor of dtype %s cannot cast to dtype %s",
+      torch::executor::toString(a.scalar_type()),
+      torch::executor::toString(dtype));
+
+  return true;
+}
+
+inline bool tensor_is_bool_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      t.scalar_type() == exec_aten::ScalarType::Bool,
+      "Expected to find bool type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_integral_type(
+    exec_aten::Tensor t,
+    bool includeBool = false) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isIntegralType(t.scalar_type(), includeBool),
+      "Expected to find a integral type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_floating_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isFloatingType(t.scalar_type()),
+      "Expected to find a floating type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_complex_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isComplexType(t.scalar_type()),
+      "Expected to find a complex type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_bits_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isBitsType(t.scalar_type()),
+      "Expected to find a bits type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
 inline bool tensors_have_same_dtype(exec_aten::Tensor a, exec_aten::Tensor b) {
-  return a.scalar_type() == b.scalar_type();
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      a.scalar_type() == b.scalar_type(),
+      ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s}",
+      torch::executor::toString(a.scalar_type()),
+      torch::executor::toString(b.scalar_type()));
+  return true;
 }
 
 inline bool tensors_have_same_dtype(
     exec_aten::Tensor a,
     exec_aten::Tensor b,
     exec_aten::Tensor c) {
-  return a.scalar_type() == b.scalar_type() &&
-      b.scalar_type() == c.scalar_type();
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      a.scalar_type() == b.scalar_type() && b.scalar_type() == c.scalar_type(),
+      ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s, %s}",
+      torch::executor::toString(a.scalar_type()),
+      torch::executor::toString(b.scalar_type()),
+      torch::executor::toString(c.scalar_type()));
+  return true;
+}
+
+inline bool tensor_is_rank(exec_aten::Tensor t, size_t rank) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      t.dim() == rank,
+      "Expected tensor.dim() to be %zu, but got %zu",
+      static_cast<size_t>(rank),
+      static_cast<size_t>(t.dim()));
+
+  return true;
+}
+
+inline bool tensor_has_dim(exec_aten::Tensor t, int64_t d) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      d > 0 ? d < t.dim() : t.dim() + d >= 0,
+      "%zu-dim tensor does not have dim at index %zu",
+      static_cast<size_t>(t.dim()),
+      static_cast<size_t>(d));
+
+  return true;
+}
+
+inline bool tensors_have_same_size_at_dims(
+    exec_aten::Tensor a,
+    size_t dim_a,
+    exec_aten::Tensor b,
+    size_t dim_b) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      dim_a < a.dim(),
+      "Cannot retrieve dim %zu from tensor with dim %zu",
+      static_cast<size_t>(dim_a),
+      static_cast<size_t>(a.dim()));
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      dim_b < b.dim(),
+      "Cannot retrieve dim %zu from tensor with dim %zu",
+      static_cast<size_t>(dim_b),
+      static_cast<size_t>(b.dim()));
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      a.size(dim_a) == b.size(dim_b),
+      ET_TENSOR_CHECK_PREFIX__
+      ": a.size(%zu) = %zu does not match b.size(%zu) = %zu",
+      static_cast<size_t>(dim_a),
+      static_cast<size_t>(a.size(dim_a)),
+      static_cast<size_t>(dim_b),
+      static_cast<size_t>(b.size(dim_b)));
+
+  return true;
 }
 
 inline bool tensors_have_same_shape(exec_aten::Tensor a, exec_aten::Tensor b) {
-  if (a.numel() != b.numel()) {
-    return false;
-  }
-  if (a.numel() == 1) {
+  if (a.numel() == 1 && b.numel() == 1) {
     // PyTorch operators treat all scalar tensors as the same shape even if
     // they have different dims.
     return true;
   }
-  // Does a length comparison (ensuring dims are equal) and element-by-element
-  // comparison (ensuring sizes are equal).
-  if (a.sizes() != b.sizes()) {
+  if (!(a.sizes() == b.sizes() && a.numel() == b.numel())) {
+    ET_LOG(
+        Error,
+        ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu), dim=(%zu, %zu)",
+        static_cast<size_t>(a.numel()),
+        static_cast<size_t>(b.numel()),
+        static_cast<size_t>(a.dim()),
+        static_cast<size_t>(b.dim()));
+    for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) {
+      ET_LOG(
+          Error,
+          "    size(%zu): (%zu, %zu)",
+          static_cast<size_t>(d),
+          static_cast<size_t>(a.size(d)),
+          static_cast<size_t>(b.size(d)));
+    }
+
     return false;
   }
+
   return true;
 }
 
 inline bool tensors_have_same_shape(
     exec_aten::Tensor a,
     exec_aten::Tensor b,
     exec_aten::Tensor c) {
-  return tensors_have_same_shape(a, b) && tensors_have_same_shape(b, c);
+  if (a.numel() == 1 && b.numel() == 1 && c.numel() == 1) {
+    // PyTorch operators treat all scalar tensors as the same shape even if
+    // they have different dims.
+    return true;
+  }
+  bool cond1 = (a.sizes() == b.sizes()) && (a.numel() == b.numel());
+  bool cond2 = (b.sizes() == c.sizes()) && (b.numel() == c.numel());
+
+  if (!(cond1 && cond2)) {
+    ET_LOG(
+        Error,
+        ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu, %zu), dim=(%zu, %zu, %zu)",
+        static_cast<size_t>(a.numel()),
+        static_cast<size_t>(b.numel()),
+        static_cast<size_t>(c.numel()),
+        static_cast<size_t>(a.dim()),
+        static_cast<size_t>(b.dim()),
+        static_cast<size_t>(c.dim()));
+    for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) {
+      ET_LOG(
+          Error,
+          "    size(%zu): (%zu, %zu, %zu)",
+          static_cast<size_t>(d),
+          static_cast<size_t>(a.size(d)),
+          static_cast<size_t>(b.size(d)),
+          static_cast<size_t>(c.size(d)));
+    }
+
+    return false;
+  }
+
+  return true;
 }
 
 inline bool tensors_have_same_shape_and_dtype(
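All of the single-condition checks above funnel through ET_LOG_MSG_AND_RETURN_IF_FALSE. Its exact definition lives elsewhere in tensor_util.h; the following is only an approximate sketch of its behaviour (the real macro may also print the stringified condition):

    // Approximate behaviour of the helper macro: log the formatted
    // message at Error severity, then bail out of the enclosing
    // bool-returning function.
    #define ET_LOG_MSG_AND_RETURN_IF_FALSE(cond, fmt, ...) \
      do {                                                 \
        if (!(cond)) {                                     \
          ET_LOG(Error, fmt, ##__VA_ARGS__);               \
          return false;                                    \
        }                                                  \
      } while (0)

ET_TENSOR_CHECK_PREFIX__ is likewise assumed here to be a shared string-literal prefix defined in the same header that the dtype/shape/stride messages concatenate onto.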
@@ -463,14 +622,50 @@ inline bool tensors_have_same_shape_and_dtype(
 inline bool tensors_have_same_strides(
     exec_aten::Tensor a,
     exec_aten::Tensor b) {
-  return a.strides() == b.strides();
+  if (a.strides() != b.strides()) {
+    ET_LOG(
+        Error,
+        ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu)",
+        static_cast<size_t>(a.dim()),
+        static_cast<size_t>(b.dim()));
+    for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) {
+      ET_LOG(
+          Error,
+          "    stride(%zu): (%zu, %zu)",
+          static_cast<size_t>(d),
+          static_cast<size_t>(a.strides()[d]),
+          static_cast<size_t>(b.strides()[d]));
+    }
+
+    return false;
+  }
+  return true;
 }
 
 inline bool tensors_have_same_strides(
     exec_aten::Tensor a,
     exec_aten::Tensor b,
     exec_aten::Tensor c) {
-  return a.strides() == b.strides() && b.strides() == c.strides();
+  if (!(a.strides() == b.strides() && b.strides() == c.strides())) {
+    ET_LOG(
+        Error,
+        ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu, %zu)",
+        static_cast<size_t>(a.dim()),
+        static_cast<size_t>(b.dim()),
+        static_cast<size_t>(c.dim()));
+    for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) {
+      ET_LOG(
+          Error,
+          "    stride(%zu): (%zu, %zu, %zu)",
+          static_cast<size_t>(d),
+          static_cast<size_t>(a.strides()[d]),
+          static_cast<size_t>(b.strides()[d]),
+          static_cast<size_t>(c.strides()[d]));
+    }
+
+    return false;
+  }
+  return true;
 }
 
 inline bool tensor_is_contiguous(exec_aten::Tensor t) {
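As an illustration (hypothetical tensors; output reconstructed from the format strings above), comparing a contiguous {2, 3} tensor against a transposed {2, 3} view, whose strides are {3, 1} and {1, 2} respectively, would now produce log lines along these lines:

    tensors_have_same_strides(a, b); // returns false and logs:
    //   <prefix>: dim=(2, 2)
    //       stride(0): (3, 1)
    //       stride(1): (1, 2)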
@@ -480,13 +675,21 @@ inline bool tensor_is_contiguous(exec_aten::Tensor t) {
   if (strides.size() == 0) {
     return true;
   }
-  if (strides[strides.size() - 1] != 1) {
-    return false;
-  }
-  for (auto i = strides.size() - 1; i > 0; --i) {
-    if (strides[i - 1] != strides[i] * sizes[i]) {
-      return false;
-    }
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      strides[strides.size() - 1] == 1,
+      "Tensor is not contiguous; the stride of the last dimension must be 1, "
+      "but got %zu",
+      static_cast<size_t>(strides[strides.size() - 1]));
+  for (int i = strides.size() - 1; i > 0; --i) {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        strides[i - 1] == strides[i] * sizes[i],
+        "Tensor is not contiguous; the stride of dim %zu should be equal to "
+        "strides[%zu] * sizes[%zu] = %zu, but found %zu",
+        static_cast<size_t>(i - 1),
+        static_cast<size_t>(i),
+        static_cast<size_t>(i),
+        static_cast<size_t>(strides[i] * sizes[i]),
+        static_cast<size_t>(strides[i - 1]));
   }
   return true;
 }
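The invariant being enforced is that the last stride is 1 and strides[i - 1] == strides[i] * sizes[i] for every other dimension. A quick worked example for sizes {2, 3, 4}: contiguous strides must be {12, 4, 1}, since 1 * 4 == 4 and 4 * 3 == 12; a tensor with strides {12, 4, 2}, say, now logs that the last stride is 2 rather than 1, instead of silently returning false.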

runtime/core/exec_aten/util/test/tensor_util_test.cpp

Lines changed: 43 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/platform/runtime.h>
 #include <executorch/test/utils/DeathTest.h>
 #include <cmath>
 #include <limits>
@@ -29,6 +30,12 @@ class TensorUtilTest : public ::testing::Test {
   TensorFactory<ScalarType::Float> tf_float_;
   TensorFactory<ScalarType::Double> tf_double_;
   TensorFactory<ScalarType::Bool> tf_bool_;
+
+  void SetUp() override {
+    // As some of these tests cause ET_LOG to be called, the PAL must be
+    // initialized first by calling runtime_init();
+    torch::executor::runtime_init();
+  }
 };
 
 TEST_F(TensorUtilTest, IdentityChecks) {
@@ -414,6 +421,30 @@ TEST_F(TensorUtilTest, BoolTensorNotScalarFails) {
 // Tests for utility functions that check tensor attributes
 //
 
+TEST_F(TensorUtilTest, TensorIsRankTest) {
+  using namespace torch::executor;
+  Tensor a = tf_float_.ones({2, 3, 5});
+
+  EXPECT_TRUE(tensor_is_rank(a, 3));
+  EXPECT_FALSE(tensor_is_rank(a, 0));
+  EXPECT_FALSE(tensor_is_rank(a, 5));
+}
+
+TEST_F(TensorUtilTest, TensorHasDimTest) {
+  using namespace torch::executor;
+  Tensor a = tf_float_.ones({2, 3, 5});
+
+  EXPECT_TRUE(tensor_has_dim(a, 2));
+  EXPECT_TRUE(tensor_has_dim(a, 1));
+  EXPECT_TRUE(tensor_has_dim(a, 0));
+  EXPECT_TRUE(tensor_has_dim(a, -1));
+  EXPECT_TRUE(tensor_has_dim(a, -2));
+  EXPECT_TRUE(tensor_has_dim(a, -3));
+
+  EXPECT_FALSE(tensor_has_dim(a, -4));
+  EXPECT_FALSE(tensor_has_dim(a, 3));
+}
+
 TEST_F(TensorUtilTest, TensorsHaveSameDtypeTest) {
   using namespace torch::executor;
   Tensor a = tf_float_.ones({2, 3});
@@ -427,6 +458,18 @@ TEST_F(TensorUtilTest, TensorsHaveSameDtypeTest) {
   EXPECT_FALSE(tensors_have_same_dtype(a, b, d));
 }
 
+TEST_F(TensorUtilTest, TensorsHaveSameSizeAtDimTest) {
+  using namespace torch::executor;
+  Tensor a = tf_float_.ones({2, 3, 4, 5});
+  Tensor b = tf_float_.ones({5, 4, 3, 2});
+
+  EXPECT_TRUE(tensors_have_same_size_at_dims(a, 0, b, 3));
+  EXPECT_TRUE(tensors_have_same_size_at_dims(a, 1, b, 2));
+  EXPECT_FALSE(tensors_have_same_size_at_dims(a, 1, b, 0));
+  EXPECT_FALSE(tensors_have_same_size_at_dims(a, 4, b, 0));
+  EXPECT_FALSE(tensors_have_same_size_at_dims(a, 2, b, 3));
+}
+
 TEST_F(TensorUtilTest, TensorsHaveSameShapeTest) {
   using namespace torch::executor;
   Tensor a = tf_float_.ones({2, 3});