|
1 | 1 | #include "mlir/Analysis/DataFlowFramework.h" |
2 | | -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
3 | 2 | #include "llvm/Support/Debug.h" |
4 | 3 | #include "llvm/Support/raw_ostream.h" |
5 | 4 |
|
@@ -232,13 +231,13 @@ class MakeRangeOpAxisInfoVisitor final |
232 | 231 | } |
233 | 232 | }; |
234 | 233 |
|
235 | | -template <typename OpTy> |
236 | | -class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> { |
| 234 | +class ConstantOpAxisInfoVisitor final |
| 235 | + : public AxisInfoVisitorImpl<arith::ConstantOp> { |
237 | 236 | public: |
238 | | - using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl; |
| 237 | + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; |
239 | 238 |
|
240 | 239 | AxisInfo |
241 | | - getAxisInfo(OpTy op, |
| 240 | + getAxisInfo(arith::ConstantOp op, |
242 | 241 | ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { |
243 | 242 | auto intAttr = dyn_cast<IntegerAttr>(op.getValue()); |
244 | 243 | auto boolAttr = dyn_cast<BoolAttr>(op.getValue()); |
@@ -323,8 +322,7 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> { |
323 | 322 | const AxisInfo &rhs) override { |
324 | 323 | if (lhs.getConstantValue().has_value() && |
325 | 324 | rhs.getConstantValue().has_value()) { |
326 | | - if constexpr (std::is_same_v<OpTy, arith::AddIOp> || |
327 | | - std::is_same_v<OpTy, LLVM::AddOp>) { |
| 325 | + if constexpr (std::is_same_v<OpTy, arith::AddIOp>) { |
328 | 326 | return {lhs.getConstantValue().value() + |
329 | 327 | rhs.getConstantValue().value()}; |
330 | 328 | } else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) { |
@@ -1013,15 +1011,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) |
1013 | 1011 | CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>, |
1014 | 1012 | CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>, |
1015 | 1013 | CastOpAxisInfoVisitor<triton::BitcastOp>>(); |
1016 | | - // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp |
1017 | | - // when scf.for supports integer induction variables |
1018 | 1014 | visitors.append<MakeRangeOpAxisInfoVisitor>(); |
1019 | | - visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>, |
1020 | | - ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>(); |
| 1015 | + visitors.append<ConstantOpAxisInfoVisitor>(); |
1021 | 1016 | visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>, |
1022 | 1017 | AddSubOpAxisInfoVisitor<arith::AddIOp>, |
1023 | | - AddSubOpAxisInfoVisitor<arith::SubIOp>, |
1024 | | - AddSubOpAxisInfoVisitor<LLVM::AddOp>>(); |
| 1018 | + AddSubOpAxisInfoVisitor<arith::SubIOp>>(); |
1025 | 1019 | visitors.append<MulIOpAxisInfoVisitor>(); |
1026 | 1020 | visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>, |
1027 | 1021 | DivOpAxisInfoVisitor<arith::DivUIOp>>(); |
@@ -1138,17 +1132,11 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, |
1138 | 1132 |
|
1139 | 1133 | if (blockArg && blockArg.getOwner()->isEntryBlock()) { |
1140 | 1134 | Operation *op = blockArg.getOwner()->getParentOp(); |
1141 | | - if (auto fun = dyn_cast<FunctionOpInterface>(op)) |
1142 | | - initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, |
1143 | | - &knownContiguity, &knownDivisibility, |
1144 | | - &knownConstancy); |
1145 | | - // llvm codegen check alignment to generate vector load/store |
1146 | | - // would be nice if this wasn't the case |
1147 | | - else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op)) |
| 1135 | + if (auto fun = dyn_cast<FunctionOpInterface>(op)) { |
1148 | 1136 | initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, |
1149 | 1137 | &knownContiguity, &knownDivisibility, |
1150 | 1138 | &knownConstancy); |
1151 | | - else if (isa<RegionBranchOpInterface>(op)) { |
| 1139 | + } else if (isa<RegionBranchOpInterface>(op)) { |
1152 | 1140 | // scf::ForOp, scf::IfOp, scf::WhileOp |
1153 | 1141 | // Control flow operations are initialized with "unknown" state: |
1154 | 1142 | // the maximum possible divisibility, contiguity, and constancy. |
|
0 commit comments