File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed
Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -140,12 +140,19 @@ TensorPtr rand_strided(
140140 std::vector<executorch::aten::StridesType> strides,
141141 executorch::aten::ScalarType type,
142142 executorch::aten::TensorShapeDynamism dynamism) {
143+ auto upper_bound = 1 .0f ;
144+ // Adjusts the upper bound to prevent rounding to 1.0 when converting to lower-precision types.
145+ if (type == executorch::aten::ScalarType::Half) {
146+ upper_bound -= static_cast <float >(std::numeric_limits<c10::Half>::epsilon ()) / 2 ;
147+ } else if (type == executorch::aten::ScalarType::BFloat16) {
148+ upper_bound -= static_cast <float >(std::numeric_limits<c10::BFloat16>::epsilon ()) / 2 ;
149+ }
143150 return random_strided (
144151 std::move (sizes),
145152 std::move (strides),
146153 type,
147154 dynamism,
148- std::uniform_real_distribution<float >(0 .0f , 1 . 0f ));
155+ std::uniform_real_distribution<float >(0 .0f , upper_bound ));
149156}
150157
151158TensorPtr randn_strided (
You can’t perform that action at this time.
0 commit comments