|
14 | 14 | #include <cmath> |
15 | 15 | #include <cstddef> // size_t |
16 | 16 | #include <limits> |
17 | | -#include <sstream> |
18 | | -#include <vector> |
19 | 17 |
|
20 | 18 | #include <executorch/runtime/core/array_ref.h> |
21 | 19 | #include <executorch/runtime/core/error.h> |
@@ -488,26 +486,33 @@ inline bool tensor_is_type( |
488 | 486 |
|
489 | 487 | inline bool tensor_is_type( |
490 | 488 | executorch::aten::Tensor t, |
491 | | - const std::vector<executorch::aten::ScalarType>& dtypes) { |
492 | | - if (std::find(dtypes.begin(), dtypes.end(), t.scalar_type()) != |
493 | | - dtypes.end()) { |
494 | | - return true; |
495 | | - } |
| 489 | + executorch::aten::ScalarType dtype, |
| 490 | + executorch::aten::ScalarType dtype2) { |
| 491 | + ET_LOG_MSG_AND_RETURN_IF_FALSE( |
| 492 | + t.scalar_type() == dtype || t.scalar_type() == dtype2, |
| 493 | + "Expected to find %s or %s type, but tensor has type %s", |
| 494 | + torch::executor::toString(dtype), |
| 495 | + torch::executor::toString(dtype2), |
| 496 | + torch::executor::toString(t.scalar_type())); |
496 | 497 |
|
497 | | - std::stringstream dtype_ss; |
498 | | - for (size_t i = 0; i < dtypes.size(); i++) { |
499 | | - if (i != 0) { |
500 | | - dtype_ss << ", "; |
501 | | - } |
502 | | - dtype_ss << torch::executor::toString(dtypes[i]); |
503 | | - } |
| 498 | + return true; |
| 499 | +} |
504 | 500 |
|
| 501 | +inline bool tensor_is_type( |
| 502 | + executorch::aten::Tensor t, |
| 503 | + executorch::aten::ScalarType dtype, |
| 504 | + executorch::aten::ScalarType dtype2, |
| 505 | + executorch::aten::ScalarType dtype3) { |
505 | 506 | ET_LOG_MSG_AND_RETURN_IF_FALSE( |
506 | | - false, |
507 | | - "Expected to find one of %s types, but tensor has type %s", |
508 | | - dtype_ss.str().c_str(), |
| 507 | + t.scalar_type() == dtype || t.scalar_type() == dtype2 || |
| 508 | + t.scalar_type() == dtype3, |
| 509 | + "Expected to find %s, %s, or %s type, but tensor has type %s", |
| 510 | + torch::executor::toString(dtype), |
| 511 | + torch::executor::toString(dtype2), |
| 512 | + torch::executor::toString(dtype3), |
509 | 513 | torch::executor::toString(t.scalar_type())); |
510 | | - return false; |
| 514 | + |
| 515 | + return true; |
511 | 516 | } |
512 | 517 |
|
513 | 518 | inline bool tensor_is_integral_type( |
|
0 commit comments