@@ -201,27 +201,43 @@ class TestLLVMGPULegalizeOpPass final
201201 }
202202};
203203
204- // / Convention with the HAL side to pass kernel arguments.
205- // / The bindings are ordered based on binding set and binding index then
206- // / compressed and mapped to dense set of arguments.
207- // / This function looks at the symbols and return the mapping between
208- // / InterfaceBindingOp and kernel argument index.
209- // / For instance if the kernel has (set, bindings) A(0, 1), B(1, 5), C(0, 6) it
210- // / will return the mapping [A, 0], [C, 1], [B, 2]
211- static llvm::SmallDenseMap<APInt, size_t >
212- getKernelArgMapping (Operation *funcOp) {
213- llvm::SetVector<APInt> usedBindingSet;
214- funcOp->walk ([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
215- usedBindingSet.insert (subspanOp.getBinding ());
216- });
217- auto sparseBindings = usedBindingSet.takeVector ();
218- std::sort (sparseBindings.begin (), sparseBindings.end (),
219- [](APInt lhs, APInt rhs) { return lhs.ult (rhs); });
220- llvm::SmallDenseMap<APInt, size_t > mapBindingArgIndex;
221- for (auto [index, binding] : llvm::enumerate (sparseBindings)) {
222- mapBindingArgIndex[binding] = index;
204+ namespace {
205+ // / A package for the results of `analyzeSubspanOps` to avoid
206+ // / arbitrary tuples. The default values are the results for an unused
207+ // / binding, which is read-only, unused, and in address space 0.
208+ struct BindingProperties {
209+ bool readonly = true ;
210+ bool unused = true ;
211+ unsigned addressSpace = 0 ;
212+ };
213+ } // namespace
214+ // / Analyze subspan binding ops to recover properties of the binding, such as
215+ // / if it is read-only and the address space it lives in.
216+ static FailureOr<SmallVector<BindingProperties>>
217+ analyzeSubspans (llvm::SetVector<IREE::HAL::InterfaceBindingSubspanOp> &subspans,
218+ int64_t numBindings, const LLVMTypeConverter *typeConverter) {
219+ SmallVector<BindingProperties> result (numBindings, BindingProperties{});
220+ for (auto subspan : subspans) {
221+ int64_t binding = subspan.getBinding ().getSExtValue ();
222+ result[binding].unused = false ;
223+ result[binding].readonly &= IREE::HAL::bitEnumContainsAny (
224+ subspan.getDescriptorFlags ().value_or (IREE::HAL::DescriptorFlags::None),
225+ IREE::HAL::DescriptorFlags::ReadOnly);
226+ unsigned bindingAddrSpace = 0 ;
227+ auto bindingType = dyn_cast<BaseMemRefType>(subspan.getType ());
228+ if (bindingType) {
229+ bindingAddrSpace = *typeConverter->getMemRefAddressSpace (bindingType);
230+ }
231+ if (result[binding].addressSpace != 0 &&
232+ result[binding].addressSpace != bindingAddrSpace) {
233+ return subspan.emitOpError (" address space for this op (" +
234+ Twine (bindingAddrSpace) +
235+ " ) doesn't match previously found space (" +
236+ Twine (result[binding].addressSpace ) + " )" );
237+ }
238+ result[binding].addressSpace = bindingAddrSpace;
223239 }
224- return mapBindingArgIndex ;
240+ return result ;
225241}
226242
227243class ConvertFunc : public ConvertToLLVMPattern {
@@ -242,30 +258,46 @@ class ConvertFunc : public ConvertToLLVMPattern {
242258 assert (fnType.getNumInputs () == 0 && fnType.getNumResults () == 0 );
243259
244260 TypeConverter::SignatureConversion signatureConverter (/* numOrigInputs=*/ 0 );
245- auto argMapping = getKernelArgMapping (funcOp);
246- // There may be dead symbols, we pick i32 pointer as default argument type .
247- SmallVector<Type, 8 > llvmInputTypes (
248- argMapping. size (), LLVM::LLVMPointerType::get (rewriter. getContext ())) ;
261+ // Note: we assume that the pipeline layout is the same for all bindings
262+ // in this function .
263+ IREE::HAL::PipelineLayoutAttr layout;
264+ llvm::SetVector<IREE::HAL::InterfaceBindingSubspanOp> subspans ;
249265 funcOp.walk ([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
250- LLVM::LLVMPointerType llvmType;
251- if (auto memrefType = dyn_cast<BaseMemRefType>(subspanOp.getType ())) {
252- unsigned addrSpace =
253- *getTypeConverter ()->getMemRefAddressSpace (memrefType);
254- llvmType = LLVM::LLVMPointerType::get (rewriter.getContext (), addrSpace);
255- } else {
256- llvmType = LLVM::LLVMPointerType::get (rewriter.getContext ());
266+ if (!layout) {
267+ layout = subspanOp.getLayout ();
257268 }
258- llvmInputTypes[argMapping[subspanOp. getBinding ()]] = llvmType ;
269+ subspans. insert (subspanOp) ;
259270 });
260- // As a convention with HAL, push constants are appended as kernel arguments
261- // after all the binding inputs.
262- uint64_t numConstants = 0 ;
263- funcOp. walk ([&](IREE::HAL::InterfaceConstantLoadOp constantOp) {
264- numConstants =
265- std::max (constantOp. getOrdinal (). getZExtValue () + 1 , numConstants );
271+
272+ funcOp. walk ([&](IREE::HAL::InterfaceConstantLoadOp constOp) {
273+ if (!layout) {
274+ layout = constOp. getLayout ();
275+ }
276+ return WalkResult::interrupt ( );
266277 });
267- llvmInputTypes.resize (argMapping.size () + numConstants,
268- rewriter.getI32Type ());
278+
279+ int64_t numBindings = 0 ;
280+ int64_t numConstants = 0 ;
281+ if (layout) {
282+ numConstants = layout.getConstants ();
283+ numBindings = layout.getBindings ().size ();
284+ }
285+
286+ FailureOr<SmallVector<BindingProperties>> maybeBindingsInfo =
287+ analyzeSubspans (subspans, numBindings, getTypeConverter ());
288+ if (failed (maybeBindingsInfo))
289+ return failure ();
290+ auto bindingsInfo = std::move (*maybeBindingsInfo);
291+
292+ SmallVector<Type, 8 > llvmInputTypes;
293+ llvmInputTypes.reserve (numBindings + numConstants);
294+ for (const auto &info : bindingsInfo) {
295+ llvmInputTypes.push_back (
296+ LLVM::LLVMPointerType::get (rewriter.getContext (), info.addressSpace ));
297+ }
298+ // All the push constants are i32 and go at the end of the argument list.
299+ llvmInputTypes.resize (numBindings + numConstants, rewriter.getI32Type ());
300+
269301 if (!llvmInputTypes.empty ())
270302 signatureConverter.addInputs (llvmInputTypes);
271303
@@ -296,6 +328,37 @@ class ConvertFunc : public ConvertToLLVMPattern {
296328 return failure ();
297329 }
298330
331+ // Set argument attributes.
332+ Attribute unit = rewriter.getUnitAttr ();
333+ for (auto [idx, info] : llvm::enumerate (bindingsInfo)) {
334+ // As a convention with HAL all the kernel argument pointers are 16Bytes
335+ // aligned.
336+ newFuncOp.setArgAttr (idx, LLVM::LLVMDialect::getAlignAttrName (),
337+ rewriter.getI32IntegerAttr (16 ));
338+ // It is safe to set the noalias attribute as it is guaranteed that the
339+ // ranges within bindings won't alias.
340+ newFuncOp.setArgAttr (idx, LLVM::LLVMDialect::getNoAliasAttrName (), unit);
341+ newFuncOp.setArgAttr (idx, LLVM::LLVMDialect::getNonNullAttrName (), unit);
342+ newFuncOp.setArgAttr (idx, LLVM::LLVMDialect::getNoUndefAttrName (), unit);
343+ if (info.unused ) {
344+ // While LLVM can work this out from the lack of use, we might as well
345+ // be explicit here just to be safe.
346+ newFuncOp.setArgAttr (idx, LLVM::LLVMDialect::getReadnoneAttrName (),
347+ unit);
348+ } else if (info.readonly ) {
349+ // Setting the readonly attribute here will generate non-coherent cache
350+ // loads.
351+ newFuncOp.setArgAttr (idx, LLVM::LLVMDialect::getReadonlyAttrName (),
352+ unit);
353+ }
354+ }
355+ for (int64_t i = 0 ; i < numConstants; ++i) {
356+ // Push constants are never `undef`, annotate that here, just as with
357+ // bindings.
358+ newFuncOp.setArgAttr (numBindings + i,
359+ LLVM::LLVMDialect::getNoUndefAttrName (), unit);
360+ }
361+
299362 rewriter.eraseOp (funcOp);
300363 return success ();
301364 }
@@ -309,25 +372,6 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
309372 IREE::HAL::InterfaceBindingSubspanOp::getOperationName (), context,
310373 converter) {}
311374
312- // / Checks all subspanOps with the same binding has readonly attribute
313- static bool checkAllSubspansReadonly (LLVM::LLVMFuncOp llvmFuncOp,
314- APInt binding) {
315- bool allReadOnly = false ;
316- llvmFuncOp.walk ([&](IREE::HAL::InterfaceBindingSubspanOp op) {
317- if (op.getBinding () == binding) {
318- if (!bitEnumContainsAny (op.getDescriptorFlags ().value_or (
319- IREE::HAL::DescriptorFlags::None),
320- IREE::HAL::DescriptorFlags::ReadOnly)) {
321- allReadOnly = false ;
322- return WalkResult::interrupt ();
323- }
324- allReadOnly = true ;
325- }
326- return WalkResult::advance ();
327- });
328- return allReadOnly;
329- }
330-
331375 LogicalResult
332376 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
333377 ConversionPatternRewriter &rewriter) const override {
@@ -337,35 +381,14 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
337381 return failure ();
338382 assert (llvmFuncOp.getNumArguments () > 0 );
339383
340- auto argMapping = getKernelArgMapping (llvmFuncOp);
341384 Location loc = op->getLoc ();
342385 auto subspanOp = cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
343386 IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor (
344387 operands, op->getAttrDictionary ());
345388 MemRefType memrefType =
346389 llvm::dyn_cast<MemRefType>(subspanOp.getResult ().getType ());
347390 mlir::BlockArgument llvmBufferArg =
348- llvmFuncOp.getArgument (argMapping[subspanOp.getBinding ()]);
349- // As a convention with HAL all the kernel argument pointers are 16Bytes
350- // aligned.
351- llvmFuncOp.setArgAttr (llvmBufferArg.getArgNumber (),
352- LLVM::LLVMDialect::getAlignAttrName (),
353- rewriter.getI32IntegerAttr (16 ));
354- // It is safe to set the noalias attribute as it is guaranteed that the
355- // ranges within bindings won't alias.
356- Attribute unit = rewriter.getUnitAttr ();
357- llvmFuncOp.setArgAttr (llvmBufferArg.getArgNumber (),
358- LLVM::LLVMDialect::getNoAliasAttrName (), unit);
359- llvmFuncOp.setArgAttr (llvmBufferArg.getArgNumber (),
360- LLVM::LLVMDialect::getNonNullAttrName (), unit);
361- llvmFuncOp.setArgAttr (llvmBufferArg.getArgNumber (),
362- LLVM::LLVMDialect::getNoUndefAttrName (), unit);
363- if (checkAllSubspansReadonly (llvmFuncOp, subspanOp.getBinding ())) {
364- // Setting the readonly attribute here will generate non-coherent cache
365- // loads.
366- llvmFuncOp.setArgAttr (llvmBufferArg.getArgNumber (),
367- LLVM::LLVMDialect::getReadonlyAttrName (), unit);
368- }
391+ llvmFuncOp.getArgument (subspanOp.getBinding ().getZExtValue ());
369392 // Add the byte offset.
370393 Value llvmBufferBasePtr = llvmBufferArg;
371394
@@ -468,18 +491,12 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern {
468491 return failure ();
469492 assert (llvmFuncOp.getNumArguments () > 0 );
470493
471- auto argMapping = getKernelArgMapping (llvmFuncOp);
472494 auto ireeConstantOp = cast<IREE::HAL::InterfaceConstantLoadOp>(op);
495+ size_t numBindings = ireeConstantOp.getLayout ().getBindings ().size ();
473496 mlir::BlockArgument llvmBufferArg = llvmFuncOp.getArgument (
474- argMapping. size () + ireeConstantOp.getOrdinal ().getZExtValue ());
497+ numBindings + ireeConstantOp.getOrdinal ().getZExtValue ());
475498 assert (llvmBufferArg.getType ().isInteger (32 ));
476499
477- // Push constants are never `undef`, annotate that here, just as with
478- // bindings.
479- llvmFuncOp.setArgAttr (llvmBufferArg.getArgNumber (),
480- LLVM::LLVMDialect::getNoUndefAttrName (),
481- rewriter.getUnitAttr ());
482-
483500 Type dstType = getTypeConverter ()->convertType (ireeConstantOp.getType ());
484501 // llvm.zext requires that the result type has a larger bitwidth.
485502 if (dstType == llvmBufferArg.getType ()) {
0 commit comments