Skip to content

Commit 1338c49

Browse files
authored
Fix format issue (#303)
Fix format issues in a few source files
1 parent cb466ef commit 1338c49

File tree

1 file changed

+77
-72
lines changed

1 file changed

+77
-72
lines changed

include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 77 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ static SmallVector<utils::IteratorType> getNParallelLoopsAttrs(unsigned n) {
114114
}
115115

116116
// if order is empty, transpose the last two dimensions
117-
// otherwise, use the provided order.
117+
// otherwise, use the provided order.
118118
// The order must be a permutation of the source rank.
119119
static Value getTransposedValue(Value source, const Location loc,
120120
ConversionPatternRewriter &rewriter,
@@ -521,9 +521,6 @@ struct StoreConverter : public OpConversionPattern<triton::StoreOp> {
521521
if (!isa<ShapedType>(val.getType())) {
522522
auto sMemRef =
523523
PtrAnalysis::getScalarMemRef(op.getPtr(), ptr, loc, rewriter);
524-
auto index =
525-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
526-
.getResult();
527524
auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext());
528525
rewriter.create<affine::AffineStoreOp>(loc, val, sMemRef, zeroMap,
529526
std::nullopt);
@@ -891,7 +888,8 @@ struct CallConverter : public OpConversionPattern<triton::CallOp> {
891888
ConversionPatternRewriter &rewriter) const override {
892889
SmallVector<Value> args = adaptor.getOperands();
893890

894-
// We need to pass extra arguments added by addProgramInfo which are num_programs and program_ids
891+
// We need to pass extra arguments added by addProgramInfo which are
892+
// num_programs and program_ids
895893
if (FuncOp parentFunc = op->getParentOfType<triton::FuncOp>()) {
896894
SymbolRefAttr calleeAttr = op.getCalleeAttr();
897895
StringRef calleeName = calleeAttr.getRootReference();
@@ -914,12 +912,12 @@ struct CallConverter : public OpConversionPattern<triton::CallOp> {
914912
}
915913
}
916914

917-
auto call = rewriter.create<func::CallOp>(
918-
op.getLoc(), op.getCallee(), op.getResultTypes(), args);
915+
auto call = rewriter.create<func::CallOp>(op.getLoc(), op.getCallee(),
916+
op.getResultTypes(), args);
919917

920918
if (!call) {
921-
op.emitError("Failed to create func::CallOp");
922-
return failure();
919+
op.emitError("Failed to create func::CallOp");
920+
return failure();
923921
}
924922

925923
rewriter.replaceOp(op, call);
@@ -949,15 +947,17 @@ struct FpToFpConverter : public OpConversionPattern<triton::FpToFpOp> {
949947
auto resultWidth = getBitWidth(resultType);
950948

951949
assert(operandWidth.has_value() && resultWidth.has_value() &&
952-
"Not a float-like operand or result");
950+
"Not a float-like operand or result");
953951

954952
if (operandWidth.value() > resultWidth.value()) {
955-
Value truncatedValue = rewriter.create<arith::TruncFOp>(op.getLoc(), resultType, op.getOperand());
953+
Value truncatedValue = rewriter.create<arith::TruncFOp>(
954+
op.getLoc(), resultType, op.getOperand());
956955
rewriter.replaceOp(op, truncatedValue);
957956
return success();
958957
}
959958

960-
Value extendedValue = rewriter.create<arith::ExtFOp>(op.getLoc(), resultType, op.getOperand());
959+
Value extendedValue = rewriter.create<arith::ExtFOp>(
960+
op.getLoc(), resultType, op.getOperand());
961961
rewriter.replaceOp(op, extendedValue);
962962

963963
return success();
@@ -991,14 +991,15 @@ struct ClampConverter : public OpConversionPattern<triton::ClampFOp> {
991991
}
992992
};
993993

994-
struct PreciseSqrtConverter : public OpConversionPattern<triton::PreciseSqrtOp> {
994+
struct PreciseSqrtConverter
995+
: public OpConversionPattern<triton::PreciseSqrtOp> {
995996
using OpConversionPattern<triton::PreciseSqrtOp>::OpConversionPattern;
996997

997998
LogicalResult
998999
matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor,
9991000
ConversionPatternRewriter &rewriter) const override {
1000-
auto replacement = rewriter.create<math::SqrtOp>(
1001-
op.getLoc(), adaptor.getOperands());
1001+
auto replacement =
1002+
rewriter.create<math::SqrtOp>(op.getLoc(), adaptor.getOperands());
10021003

10031004
rewriter.replaceOp(op, replacement);
10041005
return success();
@@ -1011,8 +1012,8 @@ struct PreciseDivConverter : public OpConversionPattern<triton::PreciseDivFOp> {
10111012
LogicalResult
10121013
matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor,
10131014
ConversionPatternRewriter &rewriter) const override {
1014-
auto replacement = rewriter.create<arith::DivFOp>(
1015-
op.getLoc(), adaptor.getOperands());
1015+
auto replacement =
1016+
rewriter.create<arith::DivFOp>(op.getLoc(), adaptor.getOperands());
10161017

10171018
rewriter.replaceOp(op, replacement);
10181019
return success();
@@ -1050,10 +1051,10 @@ struct SplitConverter : public OpConversionPattern<triton::SplitOp> {
10501051

10511052
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
10521053
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
1053-
SmallVector<OpFoldResult> sizes =
1054-
llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
1055-
return rewriter.getIndexAttr(dim);
1056-
}));
1054+
SmallVector<OpFoldResult> sizes = llvm::to_vector(
1055+
llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
1056+
return rewriter.getIndexAttr(dim);
1057+
}));
10571058

10581059
SmallVector<Value> results;
10591060

@@ -1064,7 +1065,7 @@ struct SplitConverter : public OpConversionPattern<triton::SplitOp> {
10641065
offsets.push_back(rewriter.getIndexAttr(i));
10651066
sizes.push_back(rewriter.getIndexAttr(1));
10661067
Value slice = rewriter.create<tensor::ExtractSliceOp>(
1067-
loc, resultTensor, input, offsets, sizes, strides);
1068+
loc, resultTensor, input, offsets, sizes, strides);
10681069
results.push_back(slice);
10691070
}
10701071

@@ -1084,24 +1085,26 @@ struct JoinConverter : public OpConversionPattern<triton::JoinOp> {
10841085
auto resultType = cast<RankedTensorType>(op.getResult().getType());
10851086

10861087
auto loc = op.getLoc();
1087-
Value result = rewriter.create<tensor::EmptyOp>(loc, resultType.getShape(), resultType.getElementType());
1088+
Value result = rewriter.create<tensor::EmptyOp>(
1089+
loc, resultType.getShape(), resultType.getElementType());
10881090

10891091
auto shape = resultType.getShape();
10901092

10911093
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
10921094
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
1093-
SmallVector<OpFoldResult> sizes =
1094-
llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
1095-
return rewriter.getIndexAttr(dim);
1096-
}));
1095+
SmallVector<OpFoldResult> sizes = llvm::to_vector(
1096+
llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
1097+
return rewriter.getIndexAttr(dim);
1098+
}));
10971099

10981100
for (int i = 0; i < 2; ++i) {
10991101
offsets.pop_back();
11001102
sizes.pop_back();
11011103

11021104
offsets.push_back(rewriter.getIndexAttr(i));
11031105
sizes.push_back(rewriter.getIndexAttr(1));
1104-
result = rewriter.create<tensor::InsertSliceOp>(loc, inputs[i], result, offsets, sizes, strides);
1106+
result = rewriter.create<tensor::InsertSliceOp>(loc, inputs[i], result,
1107+
offsets, sizes, strides);
11051108
}
11061109

11071110
rewriter.replaceOp(op, result);
@@ -1118,7 +1121,8 @@ struct MulHiUIOpConverter : public OpConversionPattern<triton::MulhiUIOp> {
11181121
ConversionPatternRewriter &rewriter) const override {
11191122
Location loc = op.getLoc();
11201123

1121-
auto mulResult = rewriter.create<arith::MulUIExtendedOp>(loc, adaptor.getOperands());
1124+
auto mulResult =
1125+
rewriter.create<arith::MulUIExtendedOp>(loc, adaptor.getOperands());
11221126
rewriter.replaceOp(op, mulResult.getHigh());
11231127

11241128
return success();
@@ -1131,29 +1135,29 @@ struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
11311135
// true means tensor elements are zeros
11321136
// false means not zero or it cannot be determined
11331137
bool isZeroTensor(Value &v, bool integers) const {
1134-
if (auto splatOp = v.getDefiningOp<triton::SplatOp>()) {
1135-
if (auto constOp = splatOp.getSrc().getDefiningOp<arith::ConstantOp>()) {
1136-
if (auto val = dyn_cast<FloatAttr>(constOp.getValue())) {
1137-
return val.getValueAsDouble() == 0.;
1138-
}
1139-
if (auto val = dyn_cast<IntegerAttr>(constOp.getValue())) {
1140-
return val.getValue() == 0;
1141-
}
1138+
if (auto splatOp = v.getDefiningOp<triton::SplatOp>()) {
1139+
if (auto constOp = splatOp.getSrc().getDefiningOp<arith::ConstantOp>()) {
1140+
if (auto val = dyn_cast<FloatAttr>(constOp.getValue())) {
1141+
return val.getValueAsDouble() == 0.;
1142+
}
1143+
if (auto val = dyn_cast<IntegerAttr>(constOp.getValue())) {
1144+
return val.getValue() == 0;
11421145
}
1143-
return false;
11441146
}
1147+
return false;
1148+
}
11451149

1146-
if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
1147-
if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
1148-
if (denseAttr.isSplat()) {
1149-
if (integers)
1150-
return denseAttr.getSplatValue<APInt>().isZero();
1151-
return denseAttr.getSplatValue<APFloat>().isZero();
1152-
}
1150+
if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
1151+
if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
1152+
if (denseAttr.isSplat()) {
1153+
if (integers)
1154+
return denseAttr.getSplatValue<APInt>().isZero();
1155+
return denseAttr.getSplatValue<APFloat>().isZero();
11531156
}
11541157
}
1158+
}
11551159

1156-
return false;
1160+
return false;
11571161
}
11581162

11591163
LogicalResult
@@ -1170,9 +1174,10 @@ struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
11701174
bool skipC = isZeroTensor(opc, integers);
11711175
auto init =
11721176
rewriter.create<tensor::EmptyOp>(loc, dstType.getShape(), elementType);
1173-
TypedAttr constantAttr = integers ?
1174-
static_cast<TypedAttr>(rewriter.getIntegerAttr(elementType, 0)) :
1175-
static_cast<TypedAttr>(rewriter.getFloatAttr(elementType, 0));
1177+
TypedAttr constantAttr =
1178+
integers
1179+
? static_cast<TypedAttr>(rewriter.getIntegerAttr(elementType, 0))
1180+
: static_cast<TypedAttr>(rewriter.getFloatAttr(elementType, 0));
11761181

11771182
auto zero = rewriter.create<mlir::arith::ConstantOp>(
11781183
op.getLoc(), elementType, constantAttr);
@@ -1211,10 +1216,10 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
12111216

12121217
bool isReductionOpSupported(Operation *redOp) const {
12131218
return isa<arith::AddFOp, arith::AddIOp, arith::AndIOp, arith::MaximumFOp,
1214-
arith::MulFOp, arith::MulIOp, arith::MaxNumFOp, arith::MinimumFOp,
1215-
arith::MinNumFOp, arith::MinSIOp, arith::MinUIOp, arith::MaxSIOp,
1216-
arith::MaxUIOp, arith::OrIOp, arith::XOrIOp>(
1217-
redOp);
1219+
arith::MulFOp, arith::MulIOp, arith::MaxNumFOp,
1220+
arith::MinimumFOp, arith::MinNumFOp, arith::MinSIOp,
1221+
arith::MinUIOp, arith::MaxSIOp, arith::MaxUIOp, arith::OrIOp,
1222+
arith::XOrIOp>(redOp);
12181223
}
12191224

12201225
arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter,
@@ -1250,15 +1255,13 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
12501255
return rewriter.getIntegerAttr(constantType,
12511256
llvm::minIntN(bitWidth));
12521257
})
1253-
.Case<arith::MaxUIOp, arith::XOrIOp>([&](auto) {
1254-
return rewriter.getIntegerAttr(constantType, 0);
1255-
})
1258+
.Case<arith::MaxUIOp, arith::XOrIOp>(
1259+
[&](auto) { return rewriter.getIntegerAttr(constantType, 0); })
12561260
.Case([&](arith::MulFOp) {
12571261
return rewriter.getFloatAttr(constantType, 1.f);
12581262
})
1259-
.Case<arith::MulIOp, arith::AndIOp>([&](auto) {
1260-
return rewriter.getIntegerAttr(constantType, 1);
1261-
})
1263+
.Case<arith::MulIOp, arith::AndIOp>(
1264+
[&](auto) { return rewriter.getIntegerAttr(constantType, 1); })
12621265
.Case([&](arith::OrIOp) {
12631266
return rewriter.getIntegerAttr(constantType, 0);
12641267
})
@@ -1274,7 +1277,7 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
12741277

12751278
bool requiresF32Conversion(const Type elemType, Operation *redOp) const {
12761279
unsigned width =
1277-
cast<FloatType>(Float32Type::get(elemType.getContext())).getWidth();
1280+
cast<FloatType>(Float32Type::get(elemType.getContext())).getWidth();
12781281
return isa<FloatType>(elemType) &&
12791282
elemType.getIntOrFloatBitWidth() < width &&
12801283
isa<arith::AddFOp>(redOp);
@@ -1995,10 +1998,10 @@ class AddPtrConverter : public OpConversionPattern<triton::AddPtrOp> {
19951998
op, op->getResultTypes(), op->getOperands(), outputs, indexingMaps,
19961999
iteratorTypes,
19972000
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
1998-
auto resultTypes = llvm::map_to_vector(
1999-
op->getResultTypes(), [](Type type) {
2000-
return cast<TensorType>(type).getElementType();
2001-
});
2001+
auto resultTypes =
2002+
llvm::map_to_vector(op->getResultTypes(), [](Type type) {
2003+
return cast<TensorType>(type).getElementType();
2004+
});
20022005
auto *scalarOp =
20032006
builder.create(loc, op->getName().getIdentifier(),
20042007
regionArgs.take_front(op->getNumOperands()),
@@ -2023,7 +2026,8 @@ class TensorOpConverter : public OpConversionPattern<OpType> {
20232026
LogicalResult
20242027
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
20252028
ConversionPatternRewriter &rewriter) const override {
2026-
auto resultTensorType = dyn_cast<RankedTensorType>(op.getResult().getType());
2029+
auto resultTensorType =
2030+
dyn_cast<RankedTensorType>(op.getResult().getType());
20272031
if (!resultTensorType) {
20282032
return failure();
20292033
}
@@ -2040,10 +2044,10 @@ class TensorOpConverter : public OpConversionPattern<OpType> {
20402044
op, op->getResultTypes(), op->getOperands(), outputs, indexingMaps,
20412045
iteratorTypes,
20422046
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
2043-
auto resultTypes = llvm::map_to_vector(
2044-
op->getResultTypes(), [](Type type) {
2045-
return cast<TensorType>(type).getElementType();
2046-
});
2047+
auto resultTypes =
2048+
llvm::map_to_vector(op->getResultTypes(), [](Type type) {
2049+
return cast<TensorType>(type).getElementType();
2050+
});
20472051
auto *scalarOp =
20482052
builder.create(loc, op->getName().getIdentifier(),
20492053
regionArgs.take_front(op->getNumOperands()),
@@ -2078,9 +2082,10 @@ class StorePtrToLinalgConverter : public OpConversionPattern<triton::StoreOp> {
20782082
op, op->getResultTypes(), op->getOperands(), outputs, indexingMaps,
20792083
iteratorTypes,
20802084
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
2081-
auto resultTypes = llvm::map_to_vector(op->getResultTypes(), [](Type type) {
2082-
return cast<TensorType>(type).getElementType();
2083-
});
2085+
auto resultTypes =
2086+
llvm::map_to_vector(op->getResultTypes(), [](Type type) {
2087+
return cast<TensorType>(type).getElementType();
2088+
});
20842089
auto *scalarOp =
20852090
builder.create(loc, op->getName().getIdentifier(),
20862091
regionArgs.take_front(op->getNumOperands()),

0 commit comments

Comments
 (0)