Skip to content

Commit 4563528

Browse files
committed
Log dtype names on input dtype mismatch (pytorch#7537)
Summary: Update the error message when input tensor scalar type is incorrect. We've seen this get hit a few times and it should be easier to debug than it is. New Message: ``` [method.cpp:834] Input 0 has unexpected scalar type: expected Float but was Byte. ``` Old Message: ``` [method.cpp:826] The 0-th input tensor's scalartype does not meet requirement: found 0 but expected 6 ``` Test Plan: Built executorch bento kernel locally and tested with an incorrect scalar type to view the new error message. ``` [method.cpp:834] Input 0 has unexpected scalar type: expected Float but was Byte. ``` I also locally patched and built the bento kernel with ET_ENABLE_ENUM_STRINGS=0. ``` [method.cpp:834] Input 0 has unexpected scalar type: expected 6 but was 0. ``` Reviewed By: digantdesai, SS-JIA Differential Revision: D67887770 Pulled By: GregoryComer
1 parent 39e8538 commit 4563528

File tree

7 files changed

+90
-5
lines changed

7 files changed

+90
-5
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
2+
#include <executorch/runtime/platform/log.h>
3+
4+
namespace executorch {
5+
namespace runtime {
6+
7+
/**
8+
* Convert a scalar type value to a string representation. If
9+
* ET_ENABLE_ENUM_STRINGS is set (it is on bby default), this will return a
10+
* string name (for example, "Float"). Otherwise, it will return a string
11+
* representation of the index value ("6").
12+
*
13+
* If the user buffer is not large enough to hold the string representation, the
14+
* string will be truncated.
15+
*
16+
* The return value is the number of characters written, or in the case of
17+
* truncation, the number of characters that would be written if the buffer was
18+
* large enough.
19+
*/
20+
size_t scalar_type_to_string(
21+
::executorch::aten::ScalarType t,
22+
char* buffer,
23+
size_t buffer_size) {
24+
#if ET_ENABLE_ENUM_STRINGS
25+
const char* name_str;
26+
#define DEFINE_CASE(unused, name) \
27+
case ::executorch::aten::ScalarType::name: \
28+
name_str = #name; \
29+
break;
30+
31+
switch (t) {
32+
ET_FORALL_SCALAR_TYPES(DEFINE_CASE)
33+
default:
34+
name_str = "Unknown";
35+
break;
36+
}
37+
38+
return snprintf(buffer, buffer_size, "%s", name_str);
39+
#undef DEFINE_CASE
40+
#else
41+
return snprintf(buffer, buffer_size, "%d", static_cast<int>(t));
42+
#endif // ET_ENABLE_ENUM_TO_STRING
43+
}
44+
45+
} // namespace runtime
46+
} // namespace executorch

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,24 @@ struct promote_types {
12941294
CTYPE_ALIAS, \
12951295
__VA_ARGS__))
12961296

1297+
/**
1298+
* Convert a scalar type value to a string representation. If
1299+
* ET_ENABLE_ENUM_STRINGS is set (it is on bby default), this will return a
1300+
* string name (for example, "Float"). Otherwise, it will return a string
1301+
* representation of the index value ("6").
1302+
*
1303+
* If the user buffer is not large enough to hold the string representation, the
1304+
* string will be truncated.
1305+
*
1306+
* The return value is the number of characters written, or in the case of
1307+
* truncation, the number of characters that would be written if the buffer was
1308+
* large enough.
1309+
*/
1310+
size_t scalar_type_to_string(
1311+
::executorch::aten::ScalarType t,
1312+
char* buffer,
1313+
size_t buffer_size);
1314+
12971315
} // namespace runtime
12981316
} // namespace executorch
12991317

runtime/core/exec_aten/util/targets.bzl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@ def define_common_targets():
1919

2020
runtime.cxx_library(
2121
name = "scalar_type_util" + aten_suffix,
22-
srcs = [],
22+
srcs = ["scalar_type_util.cpp"],
2323
exported_headers = [
2424
"scalar_type_util.h",
2525
],
2626
visibility = [
2727
"//executorch/...",
2828
"@EXECUTORCH_CLIENTS",
2929
],
30+
deps = [
31+
"//executorch/runtime/platform:platform",
32+
],
3033
exported_preprocessor_flags = exported_preprocessor_flags_,
3134
exported_deps = exported_deps_,
3235
exported_external_deps = ["libtorch"] if aten_mode else [],

runtime/core/portable_type/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def define_common_targets():
4949
"scalar_type.h",
5050
"qint_types.h",
5151
"bits_types.h",
52+
"string_view.h",
5253
],
5354
visibility = [
5455
"//executorch/extension/...",

runtime/executor/method.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -816,14 +816,24 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
816816
if (e.isTensor()) {
817817
const auto& t_dst = e.toTensor();
818818
const auto& t_src = input_evalue.toTensor();
819+
820+
#if ET_LOG_ENABLED
821+
char dst_type_name[16];
822+
char src_type_name[16];
823+
824+
scalar_type_to_string(
825+
t_dst.scalar_type(), dst_type_name, sizeof(dst_type_name));
826+
scalar_type_to_string(
827+
t_src.scalar_type(), src_type_name, sizeof(src_type_name));
828+
#endif
829+
819830
ET_CHECK_OR_RETURN_ERROR(
820831
t_dst.scalar_type() == t_src.scalar_type(),
821832
InvalidArgument,
822-
"The %zu-th input tensor's scalartype does not meet requirement: found %" PRId8
823-
" but expected %" PRId8,
833+
"Input %zu has unexpected scalar type: expected %s but was %s.",
824834
input_idx,
825-
static_cast<int8_t>(t_src.scalar_type()),
826-
static_cast<int8_t>(t_dst.scalar_type()));
835+
dst_type_name,
836+
src_type_name);
827837
// Reset the shape for the Method's input as the size of forwarded input
828838
// tensor for shape dynamism. Also is a safety check if need memcpy.
829839
Error err = resize_tensor(t_dst, t_src.sizes());

runtime/executor/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def define_common_targets():
8282
"//executorch/runtime/core:evalue" + aten_suffix,
8383
"//executorch/runtime/core:event_tracer" + aten_suffix,
8484
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
85+
"//executorch/runtime/core/exec_aten/util:scalar_type_util" + aten_suffix,
8586
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
8687
"//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
8788
"//executorch/runtime/kernel:operator_registry",

runtime/platform/log.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333
#define ET_LOG_ENABLED 1
3434
#endif // !defined(ET_LOG_ENABLED)
3535

36+
// Enable ET_ENABLE_ENUM_STRINGS by default. This option gates inclusion of
37+
// enum string names and can be disabled by explicitly setting it to 0.
38+
// #ifndef ET_ENABLE_ENUM_STRINGS
39+
// #define ET_ENABLE_ENUM_STRINGS 1
40+
#// endif
41+
3642
namespace executorch {
3743
namespace runtime {
3844

0 commit comments

Comments
 (0)