|
7 | 7 | #include <cfenv> |
8 | 8 | #include <cstring> |
9 | 9 | #include <iterator> |
| 10 | +#include <optional> |
10 | 11 |
|
11 | 12 | #include "openvino/core/except.hpp" |
12 | 13 | #include "openvino/core/shape.hpp" |
13 | 14 | #include "openvino/op/scatter_elements_update.hpp" |
| 15 | +#include "openvino/reference/rounding_guard.hpp" |
14 | 16 | #include "openvino/reference/utils/coordinate_index.hpp" |
15 | 17 | #include "openvino/reference/utils/coordinate_transform.hpp" |
16 | 18 |
|
@@ -133,25 +135,6 @@ typename std::enable_if<std::is_integral<T>::value, T>::type arithmetic_mean(con |
133 | 135 | return value; |
134 | 136 | } |
135 | 137 |
|
136 | | -template <typename T> |
137 | | -struct RoundingDirectionGuard { |
138 | | - RoundingDirectionGuard() { |
139 | | - if (std::is_integral<T>::value) { |
140 | | - m_original_mode = std::fegetround(); |
141 | | - std::fesetround(FE_DOWNWARD); |
142 | | - } |
143 | | - } |
144 | | - |
145 | | - ~RoundingDirectionGuard() { |
146 | | - if (std::is_integral<T>::value) { |
147 | | - std::fesetround(m_original_mode); |
148 | | - } |
149 | | - } |
150 | | - |
151 | | -private: |
152 | | - decltype(std::fegetround()) m_original_mode; |
153 | | -}; |
154 | | - |
155 | 138 | template <typename DataType> |
156 | 139 | void scatter_elem_update_with_reduction(const int64_t* indices, |
157 | 140 | const DataType* updates, |
@@ -210,11 +193,14 @@ void scatter_elem_update_with_reduction(const int64_t* indices, |
210 | 193 | if (reduction_type == ov::op::v12::ScatterElementsUpdate::Reduction::MEAN) { |
211 | 194 | // this object will change the rounding mode only for integer types which is required to match torch |
212 | 195 | // upon destruction the previously used rounding mode will be restored |
213 | | - RoundingDirectionGuard<DataType> rounding_guard; |
214 | | - for (const auto& counter : mean_reduction_counters) { |
| 196 | + std::optional<RoundingGuard> r_guard; |
| 197 | + if constexpr (std::is_integral_v<DataType>) { |
| 198 | + r_guard.emplace(FE_DOWNWARD); |
| 199 | + } |
| 200 | + for (const auto& [idx, count] : mean_reduction_counters) { |
215 | 201 | // include the initial value in the arithmetic mean divisor (if needed) |
216 | | - const auto N = counter.second + static_cast<int32_t>(use_init_val); |
217 | | - out_buf[counter.first] = arithmetic_mean<DataType>(out_buf[counter.first], N); |
| 202 | + const auto N = count + static_cast<int32_t>(use_init_val); |
| 203 | + out_buf[idx] = arithmetic_mean<DataType>(out_buf[idx], N); |
218 | 204 | } |
219 | 205 | } |
220 | 206 | } |
|
0 commit comments