@@ -211,15 +211,18 @@ void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
211
211
DynamicMemRefType<T> DN = DynamicMemRefType<T>(*N);
212
212
DynamicMemRefIterator<T> i = DM.begin ();
213
213
DynamicMemRefIterator<T> j = DN.begin ();
214
- std::pair<float , DynamicMemRefIterator<T>> max_rel_err_idx{0.0 , DM.begin ()};
215
- std::pair<float , DynamicMemRefIterator<T>> max_abs_err_idx{0.0 , DM.begin ()};
216
- for (; i != DM.end () && j != DN.end (); ++i, ++j) {
217
- const float delta = getFloat (*i) - getFloat (*j);
218
- const float delta_abs = fabs (delta);
219
- if (delta > max_abs_err_idx.first ) {
220
- max_abs_err_idx = {delta_abs, i};
221
- max_rel_err_idx = {delta, i};
222
- }
214
+ std::pair<double , DynamicMemRefIterator<T>> max_rel_err_idx{0.0 , DM.begin ()};
215
+ std::pair<double , DynamicMemRefIterator<T>> max_abs_err_idx{0.0 , DM.begin ()};
216
+ uint64_t idx = 0 ;
217
+ for (; i != DM.end () && j != DN.end (); ++i, ++j, ++idx) {
218
+ const double i_val = getFloat (*i);
219
+ const double j_val = getFloat (*j);
220
+ const double delta = fabs (i_val - j_val);
221
+ const double rel_error = delta / fmax (fabs (i_val), fabs (j_val));
222
+ if (delta > max_abs_err_idx.first )
223
+ max_abs_err_idx = {delta, i};
224
+ if (rel_error > max_rel_err_idx.first )
225
+ max_rel_err_idx = {rel_error, i};
223
226
}
224
227
std::cout << " Max absolute error " << max_abs_err_idx.first
225
228
<< " at idx=" << std::distance (DM.begin (), max_abs_err_idx.second )
0 commit comments