Skip to content

Commit 08c371d

Browse files
Refactor ImexRunnerUtils to remove some code duplications. (#695)
Co-authored-by: Artem Kroviakov <[email protected]>
1 parent 6840353 commit 08c371d

21 files changed

+290
-191
lines changed

include/imex/ExecutionEngine/ImexRunnerUtils.h

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,36 @@ template <typename T, int N> struct MemRefDescriptor {
4141
int64_t strides[N];
4242
};
4343

44+
template <typename T>
45+
void _mlir_ciface_fillResource1D(UnrankedMemRefType<T> *ptr, // NOLINT
46+
const float value);
47+
template <typename T>
48+
void _mlir_ciface_fillResource1DRandom(UnrankedMemRefType<T> *ptr,
49+
const float lower, const float upper,
50+
const bool genInt);
51+
52+
template <typename T> void _mlir_ciface_printMemref(UnrankedMemRefType<T> *M);
53+
54+
template <typename T>
55+
bool _mlir_ciface_allclose(UnrankedMemRefType<T> *M,
56+
UnrankedMemRefType<float> *N);
57+
58+
template <typename T>
59+
void _mlir_ciface_printAllclose(UnrankedMemRefType<T> *M,
60+
UnrankedMemRefType<float> *N);
61+
62+
template <typename T>
63+
void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
64+
UnrankedMemRefType<T> *N);
65+
4466
extern "C" IMEX_RUNNERUTILS_EXPORT void
45-
_mlir_ciface_fillResource1DBF16(MemRefDescriptor<bf16, 1> *ptr, // NOLINT
46-
float value);
47-
extern "C" IMEX_RUNNERUTILS_EXPORT void
48-
_mlir_ciface_fillResource1DF16(MemRefDescriptor<f16, 1> *ptr, // NOLINT
49-
float value);
50-
extern "C" IMEX_RUNNERUTILS_EXPORT void
51-
_mlir_ciface_fillResource1DF32(MemRefDescriptor<float, 1> *ptr, // NOLINT
52-
float value);
53-
extern "C" IMEX_RUNNERUTILS_EXPORT void
54-
_mlir_ciface_fillMatrixRandomBF16(MemRefDescriptor<bf16, 1> *ptr);
67+
_mlir_ciface_fillResource1DRandomBF16(UnrankedMemRefType<bf16> *ptr,
68+
const float lower, const float upper,
69+
const bool genInt);
5570
extern "C" IMEX_RUNNERUTILS_EXPORT void
56-
_mlir_ciface_fillMatrixRandomF16(MemRefDescriptor<f16, 1> *ptr);
71+
_mlir_ciface_fillResource1DRandomF16(UnrankedMemRefType<f16> *ptr,
72+
const float lower, const float upper,
73+
const bool genInt);
5774

5875
extern "C" IMEX_RUNNERUTILS_EXPORT void
5976
_mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *m);
@@ -84,4 +101,16 @@ extern "C" IMEX_RUNNERUTILS_EXPORT void
84101
_mlir_ciface_printAllcloseF32(UnrankedMemRefType<float> *M,
85102
UnrankedMemRefType<float> *N);
86103

104+
extern "C" IMEX_RUNNERUTILS_EXPORT void
105+
_mlir_ciface_printMaxErrorF16(UnrankedMemRefType<f16> *M,
106+
UnrankedMemRefType<f16> *N);
107+
108+
extern "C" IMEX_RUNNERUTILS_EXPORT void
109+
_mlir_ciface_printMaxErrorBF16(UnrankedMemRefType<bf16> *M,
110+
UnrankedMemRefType<bf16> *N);
111+
112+
extern "C" IMEX_RUNNERUTILS_EXPORT void
113+
_mlir_ciface_printMaxErrorF32(UnrankedMemRefType<float> *M,
114+
UnrankedMemRefType<float> *N);
115+
87116
#endif // IMEX_EXECUTIONENGINE_IMEXRUNNERUTILS_H

lib/ExecutionEngine/ImexRunnerUtils.cpp

Lines changed: 127 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -22,61 +22,78 @@
2222

2323
// NOLINTBEGIN(*-identifier-naming)
2424

25+
/// Fills the given 1D unranked memref with the given float value.
26+
template <typename T>
27+
void _mlir_ciface_fillResource1D(UnrankedMemRefType<T> *ptr, // NOLINT
28+
const float value) {
29+
static_assert(std::is_same_v<T, bf16> || std::is_same_v<T, f16> ||
30+
std::is_same_v<T, float>);
31+
DynamicMemRefType<T> Dptr = DynamicMemRefType<T>(*ptr);
32+
T fill_val(value);
33+
std::fill(Dptr.begin(), Dptr.end(), fill_val);
34+
}
35+
36+
template <typename T>
37+
void _mlir_ciface_fillResource1DRandom(UnrankedMemRefType<T> *ptr,
38+
const float lower, const float upper,
39+
const bool genInt) {
40+
std::random_device rd;
41+
std::mt19937 gen(rd());
42+
std::uniform_real_distribution<> dist(lower, upper);
43+
44+
DynamicMemRefType<T> Dptr = DynamicMemRefType<T>(*ptr);
45+
for (DynamicMemRefIterator<T> i = Dptr.begin(); i != Dptr.end(); ++i) {
46+
*i = T(genInt ? static_cast<int>(dist(gen)) : dist(gen));
47+
}
48+
}
49+
50+
template <typename T> void _mlir_ciface_printMemref(UnrankedMemRefType<T> *M) {
51+
impl::printMemRef(*M);
52+
}
53+
2554
/// Fills the given 1D bf16 memref with the given float value.
2655
extern "C" void
27-
_mlir_ciface_fillResource1DBF16(MemRefDescriptor<bf16, 1> *ptr, // NOLINT
56+
_mlir_ciface_fillResource1DBF16(UnrankedMemRefType<bf16> *ptr, // NOLINT
2857
float value) {
29-
bf16 bf16_val(value);
30-
std::fill_n(ptr->allocated, ptr->sizes[0], bf16_val);
58+
_mlir_ciface_fillResource1D(ptr, value);
3159
}
3260

3361
/// Fills the given 1D f16 memref with the given float value.
3462
extern "C" void
35-
_mlir_ciface_fillResource1DF16(MemRefDescriptor<f16, 1> *ptr, // NOLINT
63+
_mlir_ciface_fillResource1DF16(UnrankedMemRefType<f16> *ptr, // NOLINT
3664
float value) {
37-
f16 f16_val(value);
38-
std::fill_n(ptr->allocated, ptr->sizes[0], f16_val);
65+
_mlir_ciface_fillResource1D(ptr, value);
3966
}
4067

41-
/// Fills 1D memref of bf16 type with random values uniformly
42-
/// distributed in the range (-0.5, 0.5)
68+
/// Fills the given 1D float (f32) memref with the given float value.
4369
extern "C" void
44-
_mlir_ciface_fillMatrixRandomBF16(MemRefDescriptor<bf16, 1> *ptr) {
45-
std::random_device rd;
46-
std::mt19937 gen(rd());
47-
std::uniform_real_distribution<> dist(-0.5f, 0.5f);
48-
49-
for (int i = 0; i < ptr->sizes[0]; i++) {
50-
ptr->allocated[i] = dist(gen);
51-
}
70+
_mlir_ciface_fillResource1DF32(UnrankedMemRefType<float> *ptr, // NOLINT
71+
float value) {
72+
_mlir_ciface_fillResource1D(ptr, value);
5273
}
5374

54-
/// Fills 1D memref of f16 type with random values uniformly
55-
/// distributed in the range (-0.5, 0.5)
75+
/// Fills 1D memref of bf16 type with random values uniformly
5676
extern "C" void
57-
_mlir_ciface_fillMatrixRandomF16(MemRefDescriptor<f16, 1> *ptr) {
58-
std::random_device rd;
59-
std::mt19937 gen(rd());
60-
std::uniform_real_distribution<> dist(-0.5f, 0.5f);
61-
62-
for (int i = 0; i < ptr->sizes[0]; i++) {
63-
ptr->allocated[i] = dist(gen);
64-
}
77+
_mlir_ciface_fillResource1DRandomBF16(UnrankedMemRefType<bf16> *ptr,
78+
const float lower, const float upper,
79+
const bool genInt) {
80+
_mlir_ciface_fillResource1DRandom(ptr, lower, upper, genInt);
6581
}
6682

67-
/// Fills the given 1D float (f32) memref with the given float value.
83+
/// Fills 1D memref of f16 type with random values uniformly
6884
extern "C" void
69-
_mlir_ciface_fillResource1DF32(MemRefDescriptor<float, 1> *ptr, // NOLINT
70-
float value) {
71-
std::fill_n(ptr->allocated, ptr->sizes[0], value);
85+
_mlir_ciface_fillResource1DRandomF16(UnrankedMemRefType<f16> *ptr,
86+
const float lower, const float upper,
87+
const bool genInt) {
88+
_mlir_ciface_fillResource1DRandom(ptr, lower, upper, genInt);
7289
}
7390

7491
extern "C" void _mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *M) {
75-
impl::printMemRef(*M);
92+
_mlir_ciface_printMemref(M);
7693
}
7794

7895
extern "C" void _mlir_ciface_printMemrefF16(UnrankedMemRefType<f16> *M) {
79-
impl::printMemRef(*M);
96+
_mlir_ciface_printMemref(M);
8097
}
8198

8299
extern "C" void printMemrefBF16(int64_t rank, void *ptr) {
@@ -141,96 +158,120 @@ static float bfloat2float(uint16_t bfloatBits) {
141158
return floatBits.f;
142159
}
143160

161+
template <typename T> float getFloat(T val) {
162+
static_assert(std::is_same_v<T, bf16> || std::is_same_v<T, f16> ||
163+
std::is_same_v<T, float>);
164+
if constexpr (std::is_same_v<T, bf16>) {
165+
return bfloat2float(val.bits);
166+
} else if constexpr (std::is_same_v<T, f16>) {
167+
return half2float(val.bits);
168+
} else if constexpr (std::is_same_v<T, float>) {
169+
return val;
170+
}
171+
}
172+
144173
// For information on how to Iterate over UnrankedMemRefType, start with
145174
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
146-
extern "C" bool _mlir_ciface_allcloseF16(UnrankedMemRefType<f16> *M,
147-
UnrankedMemRefType<float> *N) {
175+
template <typename T>
176+
bool _mlir_ciface_allclose(UnrankedMemRefType<T> *M,
177+
UnrankedMemRefType<float> *N) {
148178
// atol, rtol values copied from
149179
// https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
150180
// values may need to adjusted in the future
151181
const float atol = 1e-04;
152182
const float rtol = 1e-03;
153-
DynamicMemRefType<f16> DM = DynamicMemRefType<f16>(*M);
183+
DynamicMemRefType<T> DM = DynamicMemRefType<T>(*M);
154184
DynamicMemRefType<float> DN = DynamicMemRefType<float>(*N);
155-
DynamicMemRefIterator<f16> i = DM.begin();
185+
DynamicMemRefIterator<T> i = DM.begin();
156186
DynamicMemRefIterator<float> j = DN.begin();
157187
for (; i != DM.end() && j != DN.end(); ++i, ++j) {
158-
f16 lhs = *i;
188+
float lhs = getFloat(*i);
159189
float rhs = *j;
160-
if (fabs(half2float(lhs.bits) - rhs) > atol + rtol * fabs(rhs)) {
190+
if (fabs(lhs - rhs) > atol + rtol * fabs(rhs)) {
161191
return false;
162192
}
163193
}
164194
return true;
165195
}
166196

167-
extern "C" bool _mlir_ciface_allcloseBF16(UnrankedMemRefType<bf16> *M,
168-
UnrankedMemRefType<float> *N) {
169-
// atol, rtol values copied from
170-
// https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
171-
// values may need to adjusted in the future
172-
const float atol = 1e-08;
173-
const float rtol = 1e-01;
174-
DynamicMemRefType<bf16> DM = DynamicMemRefType<bf16>(*M);
175-
DynamicMemRefType<float> DN = DynamicMemRefType<float>(*N);
176-
DynamicMemRefIterator<bf16> i = DM.begin();
177-
DynamicMemRefIterator<float> j = DN.begin();
197+
template <typename T>
198+
void _mlir_ciface_printAllclose(UnrankedMemRefType<T> *M,
199+
UnrankedMemRefType<float> *N) {
200+
if (_mlir_ciface_allclose(M, N)) {
201+
std::cout << "[ALLCLOSE: TRUE]\n";
202+
} else {
203+
std::cout << "[ALLCLOSE: FALSE]\n";
204+
}
205+
}
206+
207+
template <typename T>
208+
void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
209+
UnrankedMemRefType<T> *N) {
210+
DynamicMemRefType<T> DM = DynamicMemRefType<T>(*M);
211+
DynamicMemRefType<T> DN = DynamicMemRefType<T>(*N);
212+
DynamicMemRefIterator<T> i = DM.begin();
213+
DynamicMemRefIterator<T> j = DN.begin();
214+
std::pair<float, DynamicMemRefIterator<T>> max_rel_err_idx{0.0, DM.begin()};
215+
std::pair<float, DynamicMemRefIterator<T>> max_abs_err_idx{0.0, DM.begin()};
178216
for (; i != DM.end() && j != DN.end(); ++i, ++j) {
179-
bf16 lhs = *i;
180-
float rhs = *j;
181-
if (fabs(bfloat2float(lhs.bits) - rhs) > atol + rtol * fabs(rhs)) {
182-
return false;
217+
const float delta = getFloat(*i) - getFloat(*j);
218+
const float delta_abs = fabs(delta);
219+
if (delta > max_abs_err_idx.first) {
220+
max_abs_err_idx = {delta_abs, i};
221+
max_rel_err_idx = {delta, i};
183222
}
184223
}
185-
return true;
224+
std::cout << "Max absolute error " << max_abs_err_idx.first
225+
<< " at idx=" << std::distance(DM.begin(), max_abs_err_idx.second)
226+
<< '\n';
227+
std::cout << "Max relative error " << max_rel_err_idx.first
228+
<< " at idx=" << std::distance(DM.begin(), max_rel_err_idx.second)
229+
<< '\n';
230+
}
231+
232+
extern "C" void _mlir_ciface_printMaxErrorF16(UnrankedMemRefType<f16> *M,
233+
UnrankedMemRefType<f16> *N) {
234+
_mlir_ciface_printMaxError(M, N);
235+
}
236+
237+
extern "C" void _mlir_ciface_printMaxErrorBF16(UnrankedMemRefType<bf16> *M,
238+
UnrankedMemRefType<bf16> *N) {
239+
_mlir_ciface_printMaxError(M, N);
240+
}
241+
242+
extern "C" void _mlir_ciface_printMaxErrorF32(UnrankedMemRefType<float> *M,
243+
UnrankedMemRefType<float> *N) {
244+
_mlir_ciface_printMaxError(M, N);
245+
}
246+
247+
extern "C" bool _mlir_ciface_allcloseF16(UnrankedMemRefType<f16> *M,
248+
UnrankedMemRefType<float> *N) {
249+
return _mlir_ciface_allclose(M, N);
250+
}
251+
252+
extern "C" bool _mlir_ciface_allcloseBF16(UnrankedMemRefType<bf16> *M,
253+
UnrankedMemRefType<float> *N) {
254+
return _mlir_ciface_allclose(M, N);
186255
}
187256

188257
extern "C" bool _mlir_ciface_allcloseF32(UnrankedMemRefType<float> *M,
189258
UnrankedMemRefType<float> *N) {
190-
// atol, rtol values copied from
191-
// https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
192-
// values may need to adjusted in the future
193-
const float atol = 1e-08;
194-
const float rtol = 1e-04;
195-
DynamicMemRefType<float> DM = DynamicMemRefType<float>(*M);
196-
DynamicMemRefType<float> DN = DynamicMemRefType<float>(*N);
197-
DynamicMemRefIterator<float> i = DM.begin();
198-
DynamicMemRefIterator<float> j = DN.begin();
199-
for (; i != DM.end() && j != DN.end(); ++i, ++j) {
200-
float lhs = *i;
201-
float rhs = *j;
202-
if (fabs(lhs - rhs) > atol + rtol * fabs(rhs)) {
203-
return false;
204-
}
205-
}
206-
return true;
259+
return _mlir_ciface_allclose(M, N);
207260
}
208261

209262
extern "C" void _mlir_ciface_printAllcloseF16(UnrankedMemRefType<f16> *M,
210263
UnrankedMemRefType<float> *N) {
211-
if (_mlir_ciface_allcloseF16(M, N)) {
212-
std::cout << "[ALLCLOSE: TRUE]\n";
213-
} else {
214-
std::cout << "[ALLCLOSE: FALSE]\n";
215-
}
264+
_mlir_ciface_printAllclose(M, N);
216265
}
217266

218267
extern "C" void _mlir_ciface_printAllcloseBF16(UnrankedMemRefType<bf16> *M,
219268
UnrankedMemRefType<float> *N) {
220-
if (_mlir_ciface_allcloseBF16(M, N)) {
221-
std::cout << "[ALLCLOSE: TRUE]\n";
222-
} else {
223-
std::cout << "[ALLCLOSE: FALSE]\n";
224-
}
269+
_mlir_ciface_printAllclose(M, N);
225270
}
226271

227272
extern "C" void _mlir_ciface_printAllcloseF32(UnrankedMemRefType<float> *M,
228273
UnrankedMemRefType<float> *N) {
229-
if (_mlir_ciface_allcloseF32(M, N)) {
230-
std::cout << "[ALLCLOSE: TRUE]\n";
231-
} else {
232-
std::cout << "[ALLCLOSE: FALSE]\n";
233-
}
274+
_mlir_ciface_printAllclose(M, N);
234275
}
235276

236277
// NOLINTEND(*-identifier-naming)

test/Integration/Dialect/XeGPU/dynamic_memref.vc.mlir

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,13 @@ module @gemm attributes {gpu.container_module} {
3535
}
3636
func.func @main() attributes {llvm.emit_c_interface} {
3737
%A = memref.alloc() : memref<8x16xf16>
38-
%A_ = memref.collapse_shape %A [[0, 1]] : memref<8x16xf16> into memref<128xf16>
39-
%A_random = memref.cast %A_ : memref<128xf16> to memref<?xf16>
40-
call @fillMatrixRandomF16(%A_random) : (memref<?xf16>) -> ()
38+
%A_random = memref.cast %A : memref<8x16xf16> to memref<*xf16>
39+
%c_gen_int = arith.constant 0 : i1
40+
%cf_lower = arith.constant -0.5 : f32
41+
%cf_upper = arith.constant 0.5 : f32
42+
43+
call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
44+
4145
%B = call @test(%A) : (memref<8x16xf16>) -> memref<8x16xf32>
4246
%B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32>
4347
%A_cast = memref.cast %A : memref<8x16xf16> to memref<*xf16>
@@ -51,7 +55,7 @@ module @gemm attributes {gpu.container_module} {
5155
}
5256
func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
5357
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
54-
func.func private @fillMatrixRandomF16(memref<?xf16>) attributes {llvm.emit_c_interface}
55-
func.func private @fillResource1DF16(memref<?xf16>, f32) attributes {llvm.emit_c_interface}
58+
func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
59+
func.func private @fillResource1DF16(memref<*xf16>, f32) attributes {llvm.emit_c_interface}
5660
func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
5761
}

0 commit comments

Comments
 (0)