88
99#include < executorch/runtime/executor/method.h>
1010
11+ #include < array>
1112#include < cinttypes> // @donotremove
1213#include < cstdint>
1314#include < cstdio>
@@ -823,26 +824,43 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
823824 ET_CHECK_OR_RETURN_ERROR (
824825 input_idx < inputs_size (),
825826 InvalidArgument,
826- " Given input index must be less than the number of inputs in method, but got %zu and %zu " ,
827+ " Input index (%zu) must be less than the number of inputs in method ( %zu). " ,
827828 input_idx,
828829 inputs_size ());
829830
830831 const auto & e = get_value (get_input_index (input_idx));
831- ET_CHECK_OR_RETURN_ERROR (
832- e.isTensor () || e.isScalar (),
833- InvalidArgument,
834- " The %zu-th input in method is expected Tensor or prim, but received %" PRIu32,
835- input_idx,
836- static_cast <uint32_t >(e.tag ));
837832
838- ET_CHECK_OR_RETURN_ERROR (
839- e.tag == input_evalue.tag ,
840- InvalidArgument,
841- " The %zu-th input of method should have the same type as the input_evalue, but get tag %" PRIu32
842- " and tag %" PRIu32,
843- input_idx,
844- static_cast <uint32_t >(e.tag ),
845- static_cast <uint32_t >(input_evalue.tag ));
833+ if (!e.isTensor () && !e.isScalar ()) {
834+ #if ET_LOG_ENABLED
835+ std::array<char , kTagNameBufferSize > tag_name;
836+ tag_to_string (e.tag , tag_name.data (), tag_name.size ());
837+ ET_LOG (
838+ Error,
839+ " Input %zu was expected to be a Tensor or primitive but was %s." ,
840+ input_idx,
841+ tag_name.data ());
842+ #endif
843+
844+ return Error::InvalidArgument;
845+ }
846+
847+ if (e.tag != input_evalue.tag ) {
848+ #if ET_LOG_ENABLED
849+ std::array<char , kTagNameBufferSize > e_tag_name;
850+ std::array<char , kTagNameBufferSize > input_tag_name;
851+ tag_to_string (e.tag , e_tag_name.data (), e_tag_name.size ());
852+ tag_to_string (
853+ input_evalue.tag , input_tag_name.data (), input_tag_name.size ());
854+ ET_LOG (
855+ Error,
856+ " Input %zu was expected to have type %s but was %s." ,
857+ input_idx,
858+ e_tag_name.data (),
859+ input_tag_name.data ());
860+ #endif
861+
862+ return Error::InvalidArgument;
863+ }
846864
847865 if (e.isTensor ()) {
848866 const auto & t_dst = e.toTensor ();
@@ -932,7 +950,12 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
932950 e.toString ().data (),
933951 input_evalue.toString ().data ());
934952 } else {
935- ET_LOG (Error, " Unsupported input type: %d" , (int32_t )e.tag );
953+ #if ET_LOG_ENABLED
954+ std::array<char , kTagNameBufferSize > tag_name;
955+ tag_to_string (e.tag , tag_name.data (), tag_name.size ());
956+ ET_LOG (Error, " Unsupported input type: %s" , tag_name.data ());
957+ #endif
958+
936959 return Error::InvalidArgument;
937960 }
938961 return Error::Ok;
@@ -984,11 +1007,15 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
9841007 outputs_size ());
9851008
9861009 auto & output = mutable_value (get_output_index (output_idx));
987- ET_CHECK_OR_RETURN_ERROR (
988- output.isTensor (),
989- InvalidArgument,
990- " output type: %zu is not tensor" ,
991- (size_t )output.tag );
1010+ if (!output.isTensor ()) {
1011+ #if ET_LOG_ENABLED
1012+ std::array<char , kTagNameBufferSize > tag_name;
1013+ tag_to_string (output.tag , tag_name.data (), tag_name.size ());
1014+ ET_LOG (Error, " Output type: %s is not a tensor." , tag_name.data ());
1015+ #endif
1016+
1017+ return Error::InvalidArgument;
1018+ }
9921019
9931020 auto tensor_meta = this ->method_meta ().output_tensor_meta (output_idx);
9941021 if (tensor_meta->is_memory_planned ()) {
@@ -1001,11 +1028,16 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
10011028 }
10021029
10031030 auto & t = output.toTensor ();
1004- ET_CHECK_OR_RETURN_ERROR (
1005- output.isTensor (),
1006- InvalidArgument,
1007- " output type: %zu is not tensor" ,
1008- (size_t )output.tag );
1031+ if (!output.isTensor ()) {
1032+ #if ET_LOG_ENABLED
1033+ std::array<char , kTagNameBufferSize > tag_name;
1034+ tag_to_string (output.tag , tag_name.data (), tag_name.size ());
1035+ ET_LOG (Error, " output type: %s is not a tensor." , tag_name.data ());
1036+ #endif
1037+
1038+ return Error::InvalidArgument;
1039+ }
1040+
10091041 ET_CHECK_OR_RETURN_ERROR (
10101042 t.nbytes () <= size,
10111043 InvalidArgument,
0 commit comments