Skip to content

Commit 8c77ef8

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Add convert template
Reviewed By: kirklandsign Differential Revision: D48288081 fbshipit-source-id: affcab01a97ef240f55e4594c23b9e3e12f683bc
1 parent bd95cf9 commit 8c77ef8

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,31 @@ inline bool canCast(
455455
return true;
456456
}
457457

458+
/**
459+
* When casting from floating point to integral type, if the floating value is
460+
* outside the integral type range, then an error is thrown if sanitization is
461+
* enabled. To circumvent this, we cast the floating point to int64_t first.
462+
*/
463+
template <
464+
typename To,
465+
typename From,
466+
typename std::enable_if<
467+
(std::is_floating_point<From>::value && std::is_integral<To>::value),
468+
int>::type = 0>
469+
To convert(From val) {
470+
return static_cast<To>(static_cast<int64_t>(val));
471+
}
472+
473+
template <
474+
typename To,
475+
typename From,
476+
typename std::enable_if<
477+
!(std::is_floating_point<From>::value && std::is_integral<To>::value),
478+
int>::type = 0>
479+
To convert(From val) {
480+
return static_cast<To>(val);
481+
}
482+
458483
/**
459484
* Implements type promotion rules that are consistent with ATen behaviour,
460485
* which in turn is consistent with NumPy's promote_types.

0 commit comments

Comments
 (0)