Skip to content

Commit f8a1ebc

Browse files
authored
XeGPU Flash Attention implementation. (#703)
1 parent 385bcd2 commit f8a1ebc

File tree

3 files changed

+1084
-9
lines changed

3 files changed

+1084
-9
lines changed

lib/ExecutionEngine/ImexRunnerUtils.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -211,15 +211,18 @@ void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
211211
DynamicMemRefType<T> DN = DynamicMemRefType<T>(*N);
212212
DynamicMemRefIterator<T> i = DM.begin();
213213
DynamicMemRefIterator<T> j = DN.begin();
214-
std::pair<float, DynamicMemRefIterator<T>> max_rel_err_idx{0.0, DM.begin()};
215-
std::pair<float, DynamicMemRefIterator<T>> max_abs_err_idx{0.0, DM.begin()};
216-
for (; i != DM.end() && j != DN.end(); ++i, ++j) {
217-
const float delta = getFloat(*i) - getFloat(*j);
218-
const float delta_abs = fabs(delta);
219-
if (delta > max_abs_err_idx.first) {
220-
max_abs_err_idx = {delta_abs, i};
221-
max_rel_err_idx = {delta, i};
222-
}
214+
std::pair<double, DynamicMemRefIterator<T>> max_rel_err_idx{0.0, DM.begin()};
215+
std::pair<double, DynamicMemRefIterator<T>> max_abs_err_idx{0.0, DM.begin()};
216+
uint64_t idx = 0;
217+
for (; i != DM.end() && j != DN.end(); ++i, ++j, ++idx) {
218+
const double i_val = getFloat(*i);
219+
const double j_val = getFloat(*j);
220+
const double delta = fabs(i_val - j_val);
221+
const double rel_error = delta / fmax(fabs(i_val), fabs(j_val));
222+
if (delta > max_abs_err_idx.first)
223+
max_abs_err_idx = {delta, i};
224+
if (rel_error > max_rel_err_idx.first)
225+
max_rel_err_idx = {rel_error, i};
223226
}
224227
std::cout << "Max absolute error " << max_abs_err_idx.first
225228
<< " at idx=" << std::distance(DM.begin(), max_abs_err_idx.second)

0 commit comments

Comments
 (0)