Skip to content

Commit ee2a225

Browse files
authored
[mlir] Fix correct memset range in OwningMemRef zero-init (#158200)
`OwningMemref` allocates with overprovision + manual alignment. This is fixing the zero-initialization of the data, the existing code was potentially overrunning the allocation: ```cpp memset(descriptor.data, 0, size + desiredAlignment); // ❌ may overrun ``` This is invalid because `descriptor.data` (the aligned pointer) **does not point to the full allocated block** (`size + desiredAlignment`). Zeroing that much from the aligned start can write past the end of the allocation. Instead we only initialize the data from the aligned pointer for the expected buffer size. The padding from [allocatedPtr, alignedDataPtr] is left untouched.
1 parent 86397f5 commit ee2a225

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

mlir/include/mlir/ExecutionEngine/MemRefUtils.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,19 +164,17 @@ class OwningMemRef {
164164
int64_t nElements = 1;
165165
for (int64_t s : shapeAlloc)
166166
nElements *= s;
167-
auto [data, alignedData] =
167+
auto [allocatedPtr, alignedData] =
168168
detail::allocAligned<T>(nElements, allocFun, alignment);
169-
descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
170-
shape, shapeAlloc);
169+
descriptor = detail::makeStridedMemRefDescriptor<Rank>(
170+
allocatedPtr, alignedData, shape, shapeAlloc);
171171
if (init) {
172172
for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
173173
end = descriptor.end();
174174
it != end; ++it)
175175
init(*it, it.getIndices());
176176
} else {
177-
memset(descriptor.data, 0,
178-
nElements * sizeof(T) +
179-
alignment.value_or(detail::nextPowerOf2(sizeof(T))));
177+
memset(alignedData, 0, nElements * sizeof(T));
180178
}
181179
}
182180
/// Take ownership of an existing descriptor with a custom deleter.

mlir/unittests/ExecutionEngine/Invoke.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,24 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
251251
EXPECT_EQ((a[{2, 1}]), 42.);
252252
}
253253

254+
TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(OwningMemrefZeroInit)) {
255+
constexpr int k = 3;
256+
constexpr int m = 7;
257+
int64_t shape[] = {k, m};
258+
// Use a large alignment to stress the case where the memref data/basePtr are
259+
// disjoint.
260+
int alignment = 8192;
261+
OwningMemRef<float, 2> a(shape, {}, {}, alignment);
262+
ASSERT_EQ(
263+
(void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
264+
a->data);
265+
for (int i = 0; i < k; ++i) {
266+
for (int j = 0; j < m; ++j) {
267+
EXPECT_EQ((a[{i, j}]), 0.);
268+
}
269+
}
270+
}
271+
254272
// A helper function that will be called from the JIT
255273
static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
256274
int32_t coefficient) {

0 commit comments

Comments
 (0)