@@ -114,7 +114,7 @@ static SmallVector<utils::IteratorType> getNParallelLoopsAttrs(unsigned n) {
114114}
115115
116116// if order is empty, transpose the last two dimensions
117- // otherwise, use the provided order.
117+ // otherwise, use the provided order.
118118// The order must be a permutation of the source rank.
119119static Value getTransposedValue (Value source, const Location loc,
120120 ConversionPatternRewriter &rewriter,
@@ -521,9 +521,6 @@ struct StoreConverter : public OpConversionPattern<triton::StoreOp> {
521521 if (!isa<ShapedType>(val.getType ())) {
522522 auto sMemRef =
523523 PtrAnalysis::getScalarMemRef (op.getPtr (), ptr, loc, rewriter);
524- auto index =
525- rewriter.create <arith::ConstantOp>(loc, rewriter.getIndexAttr (0 ))
526- .getResult ();
527524 auto zeroMap = AffineMap::getConstantMap (0 , rewriter.getContext ());
528525 rewriter.create <affine::AffineStoreOp>(loc, val, sMemRef , zeroMap,
529526 std::nullopt );
@@ -891,7 +888,8 @@ struct CallConverter : public OpConversionPattern<triton::CallOp> {
891888 ConversionPatternRewriter &rewriter) const override {
892889 SmallVector<Value> args = adaptor.getOperands ();
893890
894- // We need to pass extra arguments added by addProgramInfo which are num_programs and program_ids
891+ // We need to pass extra arguments added by addProgramInfo which are
892+ // num_programs and program_ids
895893 if (FuncOp parentFunc = op->getParentOfType <triton::FuncOp>()) {
896894 SymbolRefAttr calleeAttr = op.getCalleeAttr ();
897895 StringRef calleeName = calleeAttr.getRootReference ();
@@ -914,12 +912,12 @@ struct CallConverter : public OpConversionPattern<triton::CallOp> {
914912 }
915913 }
916914
917- auto call = rewriter.create <func::CallOp>(
918- op. getLoc (), op. getCallee (), op.getResultTypes (), args);
915+ auto call = rewriter.create <func::CallOp>(op. getLoc (), op. getCallee (),
916+ op.getResultTypes (), args);
919917
920918 if (!call) {
921- op.emitError (" Failed to create func::CallOp" );
922- return failure ();
919+ op.emitError (" Failed to create func::CallOp" );
920+ return failure ();
923921 }
924922
925923 rewriter.replaceOp (op, call);
@@ -949,15 +947,17 @@ struct FpToFpConverter : public OpConversionPattern<triton::FpToFpOp> {
949947 auto resultWidth = getBitWidth (resultType);
950948
951949 assert (operandWidth.has_value () && resultWidth.has_value () &&
952- " Not a float-like operand or result" );
950+ " Not a float-like operand or result" );
953951
954952 if (operandWidth.value () > resultWidth.value ()) {
955- Value truncatedValue = rewriter.create <arith::TruncFOp>(op.getLoc (), resultType, op.getOperand ());
953+ Value truncatedValue = rewriter.create <arith::TruncFOp>(
954+ op.getLoc (), resultType, op.getOperand ());
956955 rewriter.replaceOp (op, truncatedValue);
957956 return success ();
958957 }
959958
960- Value extendedValue = rewriter.create <arith::ExtFOp>(op.getLoc (), resultType, op.getOperand ());
959+ Value extendedValue = rewriter.create <arith::ExtFOp>(
960+ op.getLoc (), resultType, op.getOperand ());
961961 rewriter.replaceOp (op, extendedValue);
962962
963963 return success ();
@@ -991,14 +991,15 @@ struct ClampConverter : public OpConversionPattern<triton::ClampFOp> {
991991 }
992992};
993993
994- struct PreciseSqrtConverter : public OpConversionPattern <triton::PreciseSqrtOp> {
994+ struct PreciseSqrtConverter
995+ : public OpConversionPattern<triton::PreciseSqrtOp> {
995996 using OpConversionPattern<triton::PreciseSqrtOp>::OpConversionPattern;
996997
997998 LogicalResult
998999 matchAndRewrite (triton::PreciseSqrtOp op, OpAdaptor adaptor,
9991000 ConversionPatternRewriter &rewriter) const override {
1000- auto replacement = rewriter. create <math::SqrtOp>(
1001- op.getLoc (), adaptor.getOperands ());
1001+ auto replacement =
1002+ rewriter. create <math::SqrtOp>( op.getLoc (), adaptor.getOperands ());
10021003
10031004 rewriter.replaceOp (op, replacement);
10041005 return success ();
@@ -1011,8 +1012,8 @@ struct PreciseDivConverter : public OpConversionPattern<triton::PreciseDivFOp> {
10111012 LogicalResult
10121013 matchAndRewrite (triton::PreciseDivFOp op, OpAdaptor adaptor,
10131014 ConversionPatternRewriter &rewriter) const override {
1014- auto replacement = rewriter. create <arith::DivFOp>(
1015- op.getLoc (), adaptor.getOperands ());
1015+ auto replacement =
1016+ rewriter. create <arith::DivFOp>( op.getLoc (), adaptor.getOperands ());
10161017
10171018 rewriter.replaceOp (op, replacement);
10181019 return success ();
@@ -1050,10 +1051,10 @@ struct SplitConverter : public OpConversionPattern<triton::SplitOp> {
10501051
10511052 SmallVector<OpFoldResult> offsets (shape.size (), rewriter.getIndexAttr (0 ));
10521053 SmallVector<OpFoldResult> strides (shape.size (), rewriter.getIndexAttr (1 ));
1053- SmallVector<OpFoldResult> sizes =
1054- llvm::to_vector ( llvm::map_range (shape, [&](int64_t dim) -> OpFoldResult {
1055- return rewriter.getIndexAttr (dim);
1056- }));
1054+ SmallVector<OpFoldResult> sizes = llvm::to_vector (
1055+ llvm::map_range (shape, [&](int64_t dim) -> OpFoldResult {
1056+ return rewriter.getIndexAttr (dim);
1057+ }));
10571058
10581059 SmallVector<Value> results;
10591060
@@ -1064,7 +1065,7 @@ struct SplitConverter : public OpConversionPattern<triton::SplitOp> {
10641065 offsets.push_back (rewriter.getIndexAttr (i));
10651066 sizes.push_back (rewriter.getIndexAttr (1 ));
10661067 Value slice = rewriter.create <tensor::ExtractSliceOp>(
1067- loc, resultTensor, input, offsets, sizes, strides);
1068+ loc, resultTensor, input, offsets, sizes, strides);
10681069 results.push_back (slice);
10691070 }
10701071
@@ -1084,24 +1085,26 @@ struct JoinConverter : public OpConversionPattern<triton::JoinOp> {
10841085 auto resultType = cast<RankedTensorType>(op.getResult ().getType ());
10851086
10861087 auto loc = op.getLoc ();
1087- Value result = rewriter.create <tensor::EmptyOp>(loc, resultType.getShape (), resultType.getElementType ());
1088+ Value result = rewriter.create <tensor::EmptyOp>(
1089+ loc, resultType.getShape (), resultType.getElementType ());
10881090
10891091 auto shape = resultType.getShape ();
10901092
10911093 SmallVector<OpFoldResult> offsets (shape.size (), rewriter.getIndexAttr (0 ));
10921094 SmallVector<OpFoldResult> strides (shape.size (), rewriter.getIndexAttr (1 ));
1093- SmallVector<OpFoldResult> sizes =
1094- llvm::to_vector ( llvm::map_range (shape, [&](int64_t dim) -> OpFoldResult {
1095- return rewriter.getIndexAttr (dim);
1096- }));
1095+ SmallVector<OpFoldResult> sizes = llvm::to_vector (
1096+ llvm::map_range (shape, [&](int64_t dim) -> OpFoldResult {
1097+ return rewriter.getIndexAttr (dim);
1098+ }));
10971099
10981100 for (int i = 0 ; i < 2 ; ++i) {
10991101 offsets.pop_back ();
11001102 sizes.pop_back ();
11011103
11021104 offsets.push_back (rewriter.getIndexAttr (i));
11031105 sizes.push_back (rewriter.getIndexAttr (1 ));
1104- result = rewriter.create <tensor::InsertSliceOp>(loc, inputs[i], result, offsets, sizes, strides);
1106+ result = rewriter.create <tensor::InsertSliceOp>(loc, inputs[i], result,
1107+ offsets, sizes, strides);
11051108 }
11061109
11071110 rewriter.replaceOp (op, result);
@@ -1118,7 +1121,8 @@ struct MulHiUIOpConverter : public OpConversionPattern<triton::MulhiUIOp> {
11181121 ConversionPatternRewriter &rewriter) const override {
11191122 Location loc = op.getLoc ();
11201123
1121- auto mulResult = rewriter.create <arith::MulUIExtendedOp>(loc, adaptor.getOperands ());
1124+ auto mulResult =
1125+ rewriter.create <arith::MulUIExtendedOp>(loc, adaptor.getOperands ());
11221126 rewriter.replaceOp (op, mulResult.getHigh ());
11231127
11241128 return success ();
@@ -1131,29 +1135,29 @@ struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
11311135 // true means tensor elements are zeros
11321136 // false means not zero or it cannot be determined
11331137 bool isZeroTensor (Value &v, bool integers) const {
1134- if (auto splatOp = v.getDefiningOp <triton::SplatOp>()) {
1135- if (auto constOp = splatOp.getSrc ().getDefiningOp <arith::ConstantOp>()) {
1136- if (auto val = dyn_cast<FloatAttr>(constOp.getValue ())) {
1137- return val.getValueAsDouble () == 0 .;
1138- }
1139- if (auto val = dyn_cast<IntegerAttr>(constOp.getValue ())) {
1140- return val.getValue () == 0 ;
1141- }
1138+ if (auto splatOp = v.getDefiningOp <triton::SplatOp>()) {
1139+ if (auto constOp = splatOp.getSrc ().getDefiningOp <arith::ConstantOp>()) {
1140+ if (auto val = dyn_cast<FloatAttr>(constOp.getValue ())) {
1141+ return val.getValueAsDouble () == 0 .;
1142+ }
1143+ if (auto val = dyn_cast<IntegerAttr>(constOp.getValue ())) {
1144+ return val.getValue () == 0 ;
11421145 }
1143- return false ;
11441146 }
1147+ return false ;
1148+ }
11451149
1146- if (auto constOp = v.getDefiningOp <arith::ConstantOp>()) {
1147- if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue ())) {
1148- if (denseAttr.isSplat ()) {
1149- if (integers)
1150- return denseAttr.getSplatValue <APInt>().isZero ();
1151- return denseAttr.getSplatValue <APFloat>().isZero ();
1152- }
1150+ if (auto constOp = v.getDefiningOp <arith::ConstantOp>()) {
1151+ if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue ())) {
1152+ if (denseAttr.isSplat ()) {
1153+ if (integers)
1154+ return denseAttr.getSplatValue <APInt>().isZero ();
1155+ return denseAttr.getSplatValue <APFloat>().isZero ();
11531156 }
11541157 }
1158+ }
11551159
1156- return false ;
1160+ return false ;
11571161 }
11581162
11591163 LogicalResult
@@ -1170,9 +1174,10 @@ struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
11701174 bool skipC = isZeroTensor (opc, integers);
11711175 auto init =
11721176 rewriter.create <tensor::EmptyOp>(loc, dstType.getShape (), elementType);
1173- TypedAttr constantAttr = integers ?
1174- static_cast <TypedAttr>(rewriter.getIntegerAttr (elementType, 0 )) :
1175- static_cast <TypedAttr>(rewriter.getFloatAttr (elementType, 0 ));
1177+ TypedAttr constantAttr =
1178+ integers
1179+ ? static_cast <TypedAttr>(rewriter.getIntegerAttr (elementType, 0 ))
1180+ : static_cast <TypedAttr>(rewriter.getFloatAttr (elementType, 0 ));
11761181
11771182 auto zero = rewriter.create <mlir::arith::ConstantOp>(
11781183 op.getLoc (), elementType, constantAttr);
@@ -1211,10 +1216,10 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
12111216
12121217 bool isReductionOpSupported (Operation *redOp) const {
12131218 return isa<arith::AddFOp, arith::AddIOp, arith::AndIOp, arith::MaximumFOp,
1214- arith::MulFOp, arith::MulIOp, arith::MaxNumFOp, arith::MinimumFOp,
1215- arith::MinNumFOp , arith::MinSIOp , arith::MinUIOp, arith::MaxSIOp ,
1216- arith::MaxUIOp , arith::OrIOp , arith::XOrIOp>(
1217- redOp);
1219+ arith::MulFOp, arith::MulIOp, arith::MaxNumFOp,
1220+ arith::MinimumFOp , arith::MinNumFOp , arith::MinSIOp ,
1221+ arith::MinUIOp , arith::MaxSIOp , arith::MaxUIOp, arith::OrIOp,
1222+ arith::XOrIOp>( redOp);
12181223 }
12191224
12201225 arith::ConstantOp getRedBaseConstOp (ConversionPatternRewriter &rewriter,
@@ -1250,15 +1255,13 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
12501255 return rewriter.getIntegerAttr (constantType,
12511256 llvm::minIntN (bitWidth));
12521257 })
1253- .Case <arith::MaxUIOp, arith::XOrIOp>([&](auto ) {
1254- return rewriter.getIntegerAttr (constantType, 0 );
1255- })
1258+ .Case <arith::MaxUIOp, arith::XOrIOp>(
1259+ [&](auto ) { return rewriter.getIntegerAttr (constantType, 0 ); })
12561260 .Case ([&](arith::MulFOp) {
12571261 return rewriter.getFloatAttr (constantType, 1 .f );
12581262 })
1259- .Case <arith::MulIOp, arith::AndIOp>([&](auto ) {
1260- return rewriter.getIntegerAttr (constantType, 1 );
1261- })
1263+ .Case <arith::MulIOp, arith::AndIOp>(
1264+ [&](auto ) { return rewriter.getIntegerAttr (constantType, 1 ); })
12621265 .Case ([&](arith::OrIOp) {
12631266 return rewriter.getIntegerAttr (constantType, 0 );
12641267 })
@@ -1274,7 +1277,7 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
12741277
12751278 bool requiresF32Conversion (const Type elemType, Operation *redOp) const {
12761279 unsigned width =
1277- cast<FloatType>(Float32Type::get (elemType.getContext ())).getWidth ();
1280+ cast<FloatType>(Float32Type::get (elemType.getContext ())).getWidth ();
12781281 return isa<FloatType>(elemType) &&
12791282 elemType.getIntOrFloatBitWidth () < width &&
12801283 isa<arith::AddFOp>(redOp);
@@ -1995,10 +1998,10 @@ class AddPtrConverter : public OpConversionPattern<triton::AddPtrOp> {
19951998 op, op->getResultTypes (), op->getOperands (), outputs, indexingMaps,
19961999 iteratorTypes,
19972000 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
1998- auto resultTypes = llvm::map_to_vector (
1999- op->getResultTypes (), [](Type type) {
2000- return cast<TensorType>(type).getElementType ();
2001- });
2001+ auto resultTypes =
2002+ llvm::map_to_vector ( op->getResultTypes (), [](Type type) {
2003+ return cast<TensorType>(type).getElementType ();
2004+ });
20022005 auto *scalarOp =
20032006 builder.create (loc, op->getName ().getIdentifier (),
20042007 regionArgs.take_front (op->getNumOperands ()),
@@ -2023,7 +2026,8 @@ class TensorOpConverter : public OpConversionPattern<OpType> {
20232026 LogicalResult
20242027 matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
20252028 ConversionPatternRewriter &rewriter) const override {
2026- auto resultTensorType = dyn_cast<RankedTensorType>(op.getResult ().getType ());
2029+ auto resultTensorType =
2030+ dyn_cast<RankedTensorType>(op.getResult ().getType ());
20272031 if (!resultTensorType) {
20282032 return failure ();
20292033 }
@@ -2040,10 +2044,10 @@ class TensorOpConverter : public OpConversionPattern<OpType> {
20402044 op, op->getResultTypes (), op->getOperands (), outputs, indexingMaps,
20412045 iteratorTypes,
20422046 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
2043- auto resultTypes = llvm::map_to_vector (
2044- op->getResultTypes (), [](Type type) {
2045- return cast<TensorType>(type).getElementType ();
2046- });
2047+ auto resultTypes =
2048+ llvm::map_to_vector ( op->getResultTypes (), [](Type type) {
2049+ return cast<TensorType>(type).getElementType ();
2050+ });
20472051 auto *scalarOp =
20482052 builder.create (loc, op->getName ().getIdentifier (),
20492053 regionArgs.take_front (op->getNumOperands ()),
@@ -2078,9 +2082,10 @@ class StorePtrToLinalgConverter : public OpConversionPattern<triton::StoreOp> {
20782082 op, op->getResultTypes (), op->getOperands (), outputs, indexingMaps,
20792083 iteratorTypes,
20802084 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
2081- auto resultTypes = llvm::map_to_vector (op->getResultTypes (), [](Type type) {
2082- return cast<TensorType>(type).getElementType ();
2083- });
2085+ auto resultTypes =
2086+ llvm::map_to_vector (op->getResultTypes (), [](Type type) {
2087+ return cast<TensorType>(type).getElementType ();
2088+ });
20842089 auto *scalarOp =
20852090 builder.create (loc, op->getName ().getIdentifier (),
20862091 regionArgs.take_front (op->getNumOperands ()),
0 commit comments