@@ -59,18 +59,21 @@ struct Descriptor {
   Value base;
   ValueRange shape;
   ValueRange strides;
+  Value paddingOption;
 };
 
 Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) {
   int rank = type.getBlockType().getRank();
-  assert(pack.size() == 1 + 2 * static_cast<size_t>(rank) &&
+  assert(pack.size() == 1 + 2 * static_cast<size_t>(rank) + 1 &&
          "Expected tensor descriptors to consist of a pointer, "
-         "followed by 'rank' shape values and 'rank' stride values.");
+         "followed by 'rank' shape values and 'rank' stride values, "
+         "followed by a padding option value.");
 
   Descriptor res;
   res.base = pack[0];
   res.shape = pack.slice(1, rank);
   res.strides = pack.slice(1 + rank, rank);
+  res.paddingOption = pack[1 + 2 * rank];
   return res;
 }
 
@@ -211,16 +214,30 @@ Value generateMask(OpBuilder &builder, const Location &loc,
 }
 
 Value generateOther(OpBuilder &builder, Location loc, Type scalarTy,
-                    ArrayRef<int64_t> blockShape) {
+                    ArrayRef<int64_t> blockShape,
+                    Value paddingOption = nullptr) {
   auto blockTy = RankedTensorType::get(blockShape, scalarTy);
-  auto attr = builder.getZeroAttr(blockTy);
-  return builder.create<arith::ConstantOp>(loc, attr);
+  if (paddingOption && mlir::isa<FloatType>(scalarTy)) {
+    auto floatTy = mlir::cast<FloatType>(scalarTy);
+    auto nan = llvm::APFloat::getNaN(floatTy.getFloatSemantics());
+    auto nanValue = builder.create<arith::ConstantOp>(
+        loc,
+        SplatElementsAttr::get(blockTy, builder.getFloatAttr(floatTy, nan)));
+    auto zeroValue = builder.create<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(blockTy, builder.getZeroAttr(floatTy)));
+    return builder.create<mlir::arith::SelectOp>(loc, paddingOption, nanValue,
+                                                 zeroValue);
+  } else {
+    auto attr = builder.getZeroAttr(blockTy);
+    return builder.create<arith::ConstantOp>(loc, attr);
+  }
 }
 
-Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy) {
+Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy,
+                    Value paddingOption = nullptr) {
   auto blockTy = descTy.getSignlessBlockType();
   return generateOther(builder, loc, blockTy.getElementType(),
-                       blockTy.getShape());
+                       blockTy.getShape(), paddingOption);
 }
 
 SmallVector<mlir::Value> castToI64(OpBuilder &builder,
@@ -237,12 +254,17 @@ struct RewriteMakeTensorDesc : OpConversionPattern<triton::MakeTensorDescOp> {
   llvm::LogicalResult
   matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<mlir::Value> ptrShapeStrides;
-    llvm::append_values(ptrShapeStrides, adaptor.getBase());
-    llvm::append_range(ptrShapeStrides,
+    SmallVector<mlir::Value> ptrShapeStridesPaddingOption;
+    llvm::append_values(ptrShapeStridesPaddingOption, adaptor.getBase());
+    llvm::append_range(ptrShapeStridesPaddingOption,
                        castToI64(rewriter, adaptor.getShape()));
-    llvm::append_range(ptrShapeStrides, adaptor.getStrides());
-    rewriter.replaceOpWithMultiple(op, {ptrShapeStrides});
+    llvm::append_range(ptrShapeStridesPaddingOption, adaptor.getStrides());
+    auto paddingOption = rewriter.create<mlir::arith::ConstantOp>(
+        op.getLoc(), rewriter.getI1Type(),
+        rewriter.getBoolAttr(adaptor.getPadding() ==
+                             triton::PaddingOption::PAD_NAN));
+    llvm::append_values(ptrShapeStridesPaddingOption, paddingOption);
+    rewriter.replaceOpWithMultiple(op, {ptrShapeStridesPaddingOption});
     return mlir::success();
   }
 };
@@ -258,12 +280,11 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
     auto descTy = op.getDesc().getType();
     auto desc = unpackDescriptor(descTy, adaptor.getDesc());
    auto offsets = castToI64(rewriter, op.getIndices());
-
+    auto other = generateOther(rewriter, loc, descTy, desc.paddingOption);
     auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
         op, generatePtr(rewriter, loc, blockShape, desc, offsets),
-        generateMask(rewriter, loc, blockShape, desc, offsets),
-        generateOther(rewriter, loc, descTy), triton::CacheModifier::NONE,
-        triton::EvictionPolicy::NORMAL, false);
+        generateMask(rewriter, loc, blockShape, desc, offsets), other,
+        triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL, false);
     newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));
 
     return llvm::success();
@@ -327,7 +348,7 @@ struct RewriteGatherPattern : OpConversionPattern<triton::DescriptorGatherOp> {
         rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset());
     auto other = generateOther(rewriter, loc,
                                descTy.getSignlessBlockType().getElementType(),
-                               blockShape);
+                               blockShape, desc.paddingOption);
     auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
         op, ptr, mask, other, triton::CacheModifier::NONE,
         triton::EvictionPolicy::NORMAL, false);
@@ -471,13 +492,14 @@ class TritonRewriteTensorDescriptorToPointerPass
     converter.addConversion([](mlir::triton::TensorDescType t,
                                llvm::SmallVectorImpl<mlir::Type> &out) {
       // We convert a tensor descriptor into an pointer, and a shape and stride
-      // for each dimension, i.e., we create 1+2*rank values. Note that tensor
-      // descriptors may be signed/unsigned integers whereas pointers should
-      // always be signless.
+      // for each dimension, and padding option. i.e., we create 1+2*rank+1
+      // values. Note that tensor descriptors may be signed/unsigned integers
+      // whereas pointers should always be signless.
       auto tensorType = t.getSignlessBlockType();
       out.push_back(triton::getPointerType(tensorType.getElementType()));
       out.insert(out.end(), 2 * tensorType.getRank(),
                  mlir::IntegerType::get(t.getContext(), 64));
+      out.push_back(mlir::IntegerType::get(t.getContext(), 1));
       return mlir::success();
     });
 
0 commit comments