Skip to content

Commit c52eb68

Browse files
authored
[LLVMGPU] Fix lowering of functions that don't use all bindings (iree-org#19773)
Runtime changes in [some previous PR, ask Ben] mean that now, GPU kernels are passed one pointer for each binding in the pipeline layout, whether or not they are used. When the previous behavior, which was to only pass in the needed pointers one after another, was removed, the GPU code was not updated to mach. This PR updates the conversion from func.func to llvm.func to use one pointer per binding. It also moves the setting of attributes like noundef or nonnull into the function conversion, instead of making the lowerigs for hal.interface.binding.subspan and hal.interface.constant.load reduntantly add those attributes.
1 parent 4215100 commit c52eb68

File tree

2 files changed

+136
-90
lines changed

2 files changed

+136
-90
lines changed

compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp

Lines changed: 106 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -201,27 +201,43 @@ class TestLLVMGPULegalizeOpPass final
201201
}
202202
};
203203

204-
/// Convention with the HAL side to pass kernel arguments.
205-
/// The bindings are ordered based on binding set and binding index then
206-
/// compressed and mapped to dense set of arguments.
207-
/// This function looks at the symbols and return the mapping between
208-
/// InterfaceBindingOp and kernel argument index.
209-
/// For instance if the kernel has (set, bindings) A(0, 1), B(1, 5), C(0, 6) it
210-
/// will return the mapping [A, 0], [C, 1], [B, 2]
211-
static llvm::SmallDenseMap<APInt, size_t>
212-
getKernelArgMapping(Operation *funcOp) {
213-
llvm::SetVector<APInt> usedBindingSet;
214-
funcOp->walk([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
215-
usedBindingSet.insert(subspanOp.getBinding());
216-
});
217-
auto sparseBindings = usedBindingSet.takeVector();
218-
std::sort(sparseBindings.begin(), sparseBindings.end(),
219-
[](APInt lhs, APInt rhs) { return lhs.ult(rhs); });
220-
llvm::SmallDenseMap<APInt, size_t> mapBindingArgIndex;
221-
for (auto [index, binding] : llvm::enumerate(sparseBindings)) {
222-
mapBindingArgIndex[binding] = index;
204+
namespace {
205+
/// A package for the results of `analyzeSubspanOps` to avoid
206+
/// arbitrary tuples. The default values are the results for an unused
207+
/// binding, which is read-only, unused, and in address space 0.
208+
struct BindingProperties {
209+
bool readonly = true;
210+
bool unused = true;
211+
unsigned addressSpace = 0;
212+
};
213+
} // namespace
214+
/// Analyze subspan binding ops to recover properties of the binding, such as
215+
/// if it is read-only and the address space it lives in.
216+
static FailureOr<SmallVector<BindingProperties>>
217+
analyzeSubspans(llvm::SetVector<IREE::HAL::InterfaceBindingSubspanOp> &subspans,
218+
int64_t numBindings, const LLVMTypeConverter *typeConverter) {
219+
SmallVector<BindingProperties> result(numBindings, BindingProperties{});
220+
for (auto subspan : subspans) {
221+
int64_t binding = subspan.getBinding().getSExtValue();
222+
result[binding].unused = false;
223+
result[binding].readonly &= IREE::HAL::bitEnumContainsAny(
224+
subspan.getDescriptorFlags().value_or(IREE::HAL::DescriptorFlags::None),
225+
IREE::HAL::DescriptorFlags::ReadOnly);
226+
unsigned bindingAddrSpace = 0;
227+
auto bindingType = dyn_cast<BaseMemRefType>(subspan.getType());
228+
if (bindingType) {
229+
bindingAddrSpace = *typeConverter->getMemRefAddressSpace(bindingType);
230+
}
231+
if (result[binding].addressSpace != 0 &&
232+
result[binding].addressSpace != bindingAddrSpace) {
233+
return subspan.emitOpError("address space for this op (" +
234+
Twine(bindingAddrSpace) +
235+
") doesn't match previously found space (" +
236+
Twine(result[binding].addressSpace) + ")");
237+
}
238+
result[binding].addressSpace = bindingAddrSpace;
223239
}
224-
return mapBindingArgIndex;
240+
return result;
225241
}
226242

227243
class ConvertFunc : public ConvertToLLVMPattern {
@@ -242,30 +258,46 @@ class ConvertFunc : public ConvertToLLVMPattern {
242258
assert(fnType.getNumInputs() == 0 && fnType.getNumResults() == 0);
243259

244260
TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
245-
auto argMapping = getKernelArgMapping(funcOp);
246-
// There may be dead symbols, we pick i32 pointer as default argument type.
247-
SmallVector<Type, 8> llvmInputTypes(
248-
argMapping.size(), LLVM::LLVMPointerType::get(rewriter.getContext()));
261+
// Note: we assume that the pipeline layout is the same for all bindings
262+
// in this function.
263+
IREE::HAL::PipelineLayoutAttr layout;
264+
llvm::SetVector<IREE::HAL::InterfaceBindingSubspanOp> subspans;
249265
funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
250-
LLVM::LLVMPointerType llvmType;
251-
if (auto memrefType = dyn_cast<BaseMemRefType>(subspanOp.getType())) {
252-
unsigned addrSpace =
253-
*getTypeConverter()->getMemRefAddressSpace(memrefType);
254-
llvmType = LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
255-
} else {
256-
llvmType = LLVM::LLVMPointerType::get(rewriter.getContext());
266+
if (!layout) {
267+
layout = subspanOp.getLayout();
257268
}
258-
llvmInputTypes[argMapping[subspanOp.getBinding()]] = llvmType;
269+
subspans.insert(subspanOp);
259270
});
260-
// As a convention with HAL, push constants are appended as kernel arguments
261-
// after all the binding inputs.
262-
uint64_t numConstants = 0;
263-
funcOp.walk([&](IREE::HAL::InterfaceConstantLoadOp constantOp) {
264-
numConstants =
265-
std::max(constantOp.getOrdinal().getZExtValue() + 1, numConstants);
271+
272+
funcOp.walk([&](IREE::HAL::InterfaceConstantLoadOp constOp) {
273+
if (!layout) {
274+
layout = constOp.getLayout();
275+
}
276+
return WalkResult::interrupt();
266277
});
267-
llvmInputTypes.resize(argMapping.size() + numConstants,
268-
rewriter.getI32Type());
278+
279+
int64_t numBindings = 0;
280+
int64_t numConstants = 0;
281+
if (layout) {
282+
numConstants = layout.getConstants();
283+
numBindings = layout.getBindings().size();
284+
}
285+
286+
FailureOr<SmallVector<BindingProperties>> maybeBindingsInfo =
287+
analyzeSubspans(subspans, numBindings, getTypeConverter());
288+
if (failed(maybeBindingsInfo))
289+
return failure();
290+
auto bindingsInfo = std::move(*maybeBindingsInfo);
291+
292+
SmallVector<Type, 8> llvmInputTypes;
293+
llvmInputTypes.reserve(numBindings + numConstants);
294+
for (const auto &info : bindingsInfo) {
295+
llvmInputTypes.push_back(
296+
LLVM::LLVMPointerType::get(rewriter.getContext(), info.addressSpace));
297+
}
298+
// All the push constants are i32 and go at the end of the argument list.
299+
llvmInputTypes.resize(numBindings + numConstants, rewriter.getI32Type());
300+
269301
if (!llvmInputTypes.empty())
270302
signatureConverter.addInputs(llvmInputTypes);
271303

@@ -296,6 +328,37 @@ class ConvertFunc : public ConvertToLLVMPattern {
296328
return failure();
297329
}
298330

331+
// Set argument attributes.
332+
Attribute unit = rewriter.getUnitAttr();
333+
for (auto [idx, info] : llvm::enumerate(bindingsInfo)) {
334+
// As a convention with HAL all the kernel argument pointers are 16Bytes
335+
// aligned.
336+
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getAlignAttrName(),
337+
rewriter.getI32IntegerAttr(16));
338+
// It is safe to set the noalias attribute as it is guaranteed that the
339+
// ranges within bindings won't alias.
340+
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getNoAliasAttrName(), unit);
341+
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getNonNullAttrName(), unit);
342+
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getNoUndefAttrName(), unit);
343+
if (info.unused) {
344+
// While LLVM can work this out from the lack of use, we might as well
345+
// be explicit here just to be safe.
346+
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getReadnoneAttrName(),
347+
unit);
348+
} else if (info.readonly) {
349+
// Setting the readonly attribute here will generate non-coherent cache
350+
// loads.
351+
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getReadonlyAttrName(),
352+
unit);
353+
}
354+
}
355+
for (int64_t i = 0; i < numConstants; ++i) {
356+
// Push constants are never `undef`, annotate that here, just as with
357+
// bindings.
358+
newFuncOp.setArgAttr(numBindings + i,
359+
LLVM::LLVMDialect::getNoUndefAttrName(), unit);
360+
}
361+
299362
rewriter.eraseOp(funcOp);
300363
return success();
301364
}
@@ -309,25 +372,6 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
309372
IREE::HAL::InterfaceBindingSubspanOp::getOperationName(), context,
310373
converter) {}
311374

312-
/// Checks all subspanOps with the same binding has readonly attribute
313-
static bool checkAllSubspansReadonly(LLVM::LLVMFuncOp llvmFuncOp,
314-
APInt binding) {
315-
bool allReadOnly = false;
316-
llvmFuncOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp op) {
317-
if (op.getBinding() == binding) {
318-
if (!bitEnumContainsAny(op.getDescriptorFlags().value_or(
319-
IREE::HAL::DescriptorFlags::None),
320-
IREE::HAL::DescriptorFlags::ReadOnly)) {
321-
allReadOnly = false;
322-
return WalkResult::interrupt();
323-
}
324-
allReadOnly = true;
325-
}
326-
return WalkResult::advance();
327-
});
328-
return allReadOnly;
329-
}
330-
331375
LogicalResult
332376
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
333377
ConversionPatternRewriter &rewriter) const override {
@@ -337,35 +381,14 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
337381
return failure();
338382
assert(llvmFuncOp.getNumArguments() > 0);
339383

340-
auto argMapping = getKernelArgMapping(llvmFuncOp);
341384
Location loc = op->getLoc();
342385
auto subspanOp = cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
343386
IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor(
344387
operands, op->getAttrDictionary());
345388
MemRefType memrefType =
346389
llvm::dyn_cast<MemRefType>(subspanOp.getResult().getType());
347390
mlir::BlockArgument llvmBufferArg =
348-
llvmFuncOp.getArgument(argMapping[subspanOp.getBinding()]);
349-
// As a convention with HAL all the kernel argument pointers are 16Bytes
350-
// aligned.
351-
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
352-
LLVM::LLVMDialect::getAlignAttrName(),
353-
rewriter.getI32IntegerAttr(16));
354-
// It is safe to set the noalias attribute as it is guaranteed that the
355-
// ranges within bindings won't alias.
356-
Attribute unit = rewriter.getUnitAttr();
357-
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
358-
LLVM::LLVMDialect::getNoAliasAttrName(), unit);
359-
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
360-
LLVM::LLVMDialect::getNonNullAttrName(), unit);
361-
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
362-
LLVM::LLVMDialect::getNoUndefAttrName(), unit);
363-
if (checkAllSubspansReadonly(llvmFuncOp, subspanOp.getBinding())) {
364-
// Setting the readonly attribute here will generate non-coherent cache
365-
// loads.
366-
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
367-
LLVM::LLVMDialect::getReadonlyAttrName(), unit);
368-
}
391+
llvmFuncOp.getArgument(subspanOp.getBinding().getZExtValue());
369392
// Add the byte offset.
370393
Value llvmBufferBasePtr = llvmBufferArg;
371394

@@ -468,18 +491,12 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern {
468491
return failure();
469492
assert(llvmFuncOp.getNumArguments() > 0);
470493

471-
auto argMapping = getKernelArgMapping(llvmFuncOp);
472494
auto ireeConstantOp = cast<IREE::HAL::InterfaceConstantLoadOp>(op);
495+
size_t numBindings = ireeConstantOp.getLayout().getBindings().size();
473496
mlir::BlockArgument llvmBufferArg = llvmFuncOp.getArgument(
474-
argMapping.size() + ireeConstantOp.getOrdinal().getZExtValue());
497+
numBindings + ireeConstantOp.getOrdinal().getZExtValue());
475498
assert(llvmBufferArg.getType().isInteger(32));
476499

477-
// Push constants are never `undef`, annotate that here, just as with
478-
// bindings.
479-
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
480-
LLVM::LLVMDialect::getNoUndefAttrName(),
481-
rewriter.getUnitAttr());
482-
483500
Type dstType = getTypeConverter()->convertType(ireeConstantOp.getType());
484501
// llvm.zext requires that the result type has a larger bitwidth.
485502
if (dstType == llvmBufferArg.getType()) {

compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ builtin.module {
3030
// INDEX32-LABEL: llvm.func @abs_ex_dispatch_0
3131
// CHECK-SAME: (%{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef, llvm.readonly},
3232
// CHECK-SAME: %{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
33-
// CHECK-SAME: %{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef})
33+
// CHECK-SAME: %{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
34+
// CHECK-SAME: %{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef, llvm.readnone})
3435
// CHECK: rocdl.workgroup.dim.x
3536
// CHECK: llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
3637
// INDEX32: llvm.getelementptr %{{.*}} : (!llvm.ptr, i32) -> !llvm.ptr, f32
@@ -230,3 +231,31 @@ module {
230231
}
231232
// CHECK-LABEL: llvm.func @emulation_lowering(
232233
// CHECK-NOT: builtin.unrealized_conversion_cast
234+
235+
// -----
236+
// Test that an unused binding still appears in the kernargs
237+
#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
238+
#hal.pipeline.binding<storage_buffer>,
239+
#hal.pipeline.binding<storage_buffer>,
240+
#hal.pipeline.binding<storage_buffer>
241+
]>
242+
builtin.module {
243+
func.func @missing_ptr_dispatch_copy_idx_0() {
244+
%c0 = arith.constant 0 : index
245+
%0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
246+
%1 = arith.index_castui %0 : i32 to index
247+
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%1) flags(ReadOnly) : memref<16xf32, strided<[1], offset : ?>, #gpu.address_space<global>>
248+
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<16xf32, #gpu.address_space<global>>
249+
%4 = memref.load %2[%c0] : memref<16xf32, strided<[1], offset : ?>, #gpu.address_space<global>>
250+
memref.store %4, %3[%c0] : memref<16xf32, #gpu.address_space<global>>
251+
return
252+
}
253+
}
254+
// CHECK-LABEL: llvm.func @missing_ptr_dispatch_copy_idx_0
255+
// CHECK-SAME: (%[[arg0:.+]]: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef, llvm.readonly},
256+
// CHECK-SAME: %[[arg1:.+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef, llvm.readnone},
257+
// CHECK-SAME: %[[arg2:.+]]: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
258+
// CHECK-SAME: %[[arg3:.+]]: i32 {llvm.noundef})
259+
// CHECK: llvm.zext %[[arg3]] : i32 to i64
260+
// CHECK: llvm.insertvalue %[[arg0]]
261+
// CHECK: llvm.insertvalue %[[arg2]]

0 commit comments

Comments
 (0)