Skip to content

Commit 0ff92fe

Browse files
[mlir][LLVM][NFC] Simplify computeSizes function (llvm#153588)
Rename `computeSizes` to `computeSize` and make it compute just a single size. This is in preparation of adding 1:N support to the Func->LLVM lowering patterns.
1 parent 7d91213 commit 0ff92fe

File tree

4 files changed

+41
-58
lines changed

4 files changed

+41
-58
lines changed

mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,13 @@ class UnrankedMemRefDescriptor : public StructBuilder {
189189
/// `unpack`.
190190
static unsigned getNumUnpackedValues() { return 2; }
191191

192-
/// Builds IR computing the sizes in bytes (suitable for opaque allocation)
193-
/// and appends the corresponding values into `sizes`. `addressSpaces`
194-
/// which must have the same length as `values`, is needed to handle layouts
195-
/// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
196-
static void computeSizes(OpBuilder &builder, Location loc,
192+
/// Builds and returns IR computing the size in bytes (suitable for opaque
193+
/// allocation). `addressSpace` is needed to handle layouts where
194+
/// sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
195+
static Value computeSize(OpBuilder &builder, Location loc,
197196
const LLVMTypeConverter &typeConverter,
198-
ArrayRef<UnrankedMemRefDescriptor> values,
199-
ArrayRef<unsigned> addressSpaces,
200-
SmallVectorImpl<Value> &sizes);
197+
UnrankedMemRefDescriptor desc,
198+
unsigned addressSpace);
201199

202200
/// TODO: The following accessors don't take alignment rules between elements
203201
/// of the descriptor struct into account. For some architectures, it might be

mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
353353
results.push_back(d.memRefDescPtr(builder, loc));
354354
}
355355

356-
void UnrankedMemRefDescriptor::computeSizes(
356+
Value UnrankedMemRefDescriptor::computeSize(
357357
OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
358-
ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
359-
SmallVectorImpl<Value> &sizes) {
360-
if (values.empty())
361-
return;
362-
assert(values.size() == addressSpaces.size() &&
363-
"must provide address space for each descriptor");
358+
UnrankedMemRefDescriptor desc, unsigned addressSpace) {
364359
// Cache the index type.
365360
Type indexType = typeConverter.getIndexType();
366361

@@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
371366
builder, loc, indexType,
372367
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
373368

374-
sizes.reserve(sizes.size() + values.size());
375-
for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
376-
// Emit IR computing the memory necessary to store the descriptor. This
377-
// assumes the descriptor to be
378-
// { type*, type*, index, index[rank], index[rank] }
379-
// and densely packed, so the total size is
380-
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
381-
// TODO: consider including the actual size (including eventual padding due
382-
// to data layout) into the unranked descriptor.
383-
Value pointerSize = createIndexAttrConstant(
384-
builder, loc, indexType,
385-
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
386-
Value doublePointerSize =
387-
LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
388-
389-
// (1 + 2 * rank) * sizeof(index)
390-
Value rank = desc.rank(builder, loc);
391-
Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
392-
Value doubleRankIncremented =
393-
LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
394-
Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
395-
doubleRankIncremented, indexSize);
396-
397-
// Total allocation size.
398-
Value allocationSize = LLVM::AddOp::create(
399-
builder, loc, indexType, doublePointerSize, rankIndexSize);
400-
sizes.push_back(allocationSize);
401-
}
369+
// Emit IR computing the memory necessary to store the descriptor. This
370+
// assumes the descriptor to be
371+
// { type*, type*, index, index[rank], index[rank] }
372+
// and densely packed, so the total size is
373+
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
374+
// TODO: consider including the actual size (including eventual padding due
375+
// to data layout) into the unranked descriptor.
376+
Value pointerSize = createIndexAttrConstant(
377+
builder, loc, indexType,
378+
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
379+
Value doublePointerSize =
380+
LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
381+
382+
// (1 + 2 * rank) * sizeof(index)
383+
Value rank = desc.rank(builder, loc);
384+
Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
385+
Value doubleRankIncremented =
386+
LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
387+
Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
388+
doubleRankIncremented, indexSize);
389+
390+
// Total allocation size.
391+
Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
392+
doublePointerSize, rankIndexSize);
393+
return allocationSize;
402394
}
403395

404396
Value UnrankedMemRefDescriptor::allocatedPtr(

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,6 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
239239
if (unrankedMemrefs.empty())
240240
return success();
241241

242-
// Compute allocation sizes.
243-
SmallVector<Value> sizes;
244-
UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
245-
unrankedMemrefs, unrankedAddressSpaces,
246-
sizes);
247-
248242
// Get frequently used types.
249243
Type indexType = getTypeConverter()->getIndexType();
250244

@@ -267,8 +261,10 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
267261
Type type = origTypes[i];
268262
if (!isa<UnrankedMemRefType>(type))
269263
continue;
270-
Value allocationSize = sizes[unrankedMemrefPos++];
271264
UnrankedMemRefDescriptor desc(operands[i]);
265+
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
266+
builder, loc, *getTypeConverter(), desc,
267+
unrankedAddressSpaces[unrankedMemrefPos++]);
272268

273269
// Allocate memory, copy, and free the source if necessary.
274270
Value memory =

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
12461246
auto result = UnrankedMemRefDescriptor::poison(
12471247
rewriter, loc, typeConverter->convertType(resultTypeU));
12481248
result.setRank(rewriter, loc, rank);
1249-
SmallVector<Value, 1> sizes;
1250-
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1251-
result, resultAddrSpace, sizes);
1252-
Value resultUnderlyingSize = sizes.front();
1249+
Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
1250+
rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
12531251
Value resultUnderlyingDesc =
12541252
LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
12551253
rewriter.getI8Type(), resultUnderlyingSize);
@@ -1530,12 +1528,11 @@ struct MemRefReshapeOpLowering
15301528
auto targetDesc = UnrankedMemRefDescriptor::poison(
15311529
rewriter, loc, typeConverter->convertType(targetType));
15321530
targetDesc.setRank(rewriter, loc, resultRank);
1533-
SmallVector<Value, 4> sizes;
1534-
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1535-
targetDesc, addressSpace, sizes);
1531+
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
1532+
rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
15361533
Value underlyingDescPtr = LLVM::AllocaOp::create(
15371534
rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
1538-
sizes.front());
1535+
allocationSize);
15391536
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
15401537

15411538
// Extract pointers and offset from the source memref.

0 commit comments

Comments
 (0)