Skip to content

Commit d9fd9c5

Browse files
authored
[Triton] Remove LLVM dialect handling from AxisInfo (#5690)
There was a FIXME left here to clean this up that's almost 2 years old.
1 parent d47e314 commit d9fd9c5

File tree

1 file changed

+9
-21
lines changed

1 file changed

+9
-21
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "mlir/Analysis/DataFlowFramework.h"
2-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
32
#include "llvm/Support/Debug.h"
43
#include "llvm/Support/raw_ostream.h"
54

@@ -232,13 +231,13 @@ class MakeRangeOpAxisInfoVisitor final
232231
}
233232
};
234233

235-
template <typename OpTy>
236-
class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
234+
class ConstantOpAxisInfoVisitor final
235+
: public AxisInfoVisitorImpl<arith::ConstantOp> {
237236
public:
238-
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
237+
using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
239238

240239
AxisInfo
241-
getAxisInfo(OpTy op,
240+
getAxisInfo(arith::ConstantOp op,
242241
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
243242
auto intAttr = dyn_cast<IntegerAttr>(op.getValue());
244243
auto boolAttr = dyn_cast<BoolAttr>(op.getValue());
@@ -323,8 +322,7 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
323322
const AxisInfo &rhs) override {
324323
if (lhs.getConstantValue().has_value() &&
325324
rhs.getConstantValue().has_value()) {
326-
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
327-
std::is_same_v<OpTy, LLVM::AddOp>) {
325+
if constexpr (std::is_same_v<OpTy, arith::AddIOp>) {
328326
return {lhs.getConstantValue().value() +
329327
rhs.getConstantValue().value()};
330328
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
@@ -1013,15 +1011,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10131011
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
10141012
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
10151013
CastOpAxisInfoVisitor<triton::BitcastOp>>();
1016-
// TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
1017-
// when scf.for supports integer induction variables
10181014
visitors.append<MakeRangeOpAxisInfoVisitor>();
1019-
visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
1020-
ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
1015+
visitors.append<ConstantOpAxisInfoVisitor>();
10211016
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
10221017
AddSubOpAxisInfoVisitor<arith::AddIOp>,
1023-
AddSubOpAxisInfoVisitor<arith::SubIOp>,
1024-
AddSubOpAxisInfoVisitor<LLVM::AddOp>>();
1018+
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
10251019
visitors.append<MulIOpAxisInfoVisitor>();
10261020
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
10271021
DivOpAxisInfoVisitor<arith::DivUIOp>>();
@@ -1138,17 +1132,11 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
11381132

11391133
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
11401134
Operation *op = blockArg.getOwner()->getParentOp();
1141-
if (auto fun = dyn_cast<FunctionOpInterface>(op))
1142-
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
1143-
&knownContiguity, &knownDivisibility,
1144-
&knownConstancy);
1145-
// llvm codegen check alignment to generate vector load/store
1146-
// would be nice if this wasn't the case
1147-
else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
1135+
if (auto fun = dyn_cast<FunctionOpInterface>(op)) {
11481136
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
11491137
&knownContiguity, &knownDivisibility,
11501138
&knownConstancy);
1151-
else if (isa<RegionBranchOpInterface>(op)) {
1139+
} else if (isa<RegionBranchOpInterface>(op)) {
11521140
// scf::ForOp, scf::IfOp, scf::WhileOp
11531141
// Control flow operations are initialized with "unknown" state:
11541142
// the maximum possible divisibility, contiguity, and constancy.

0 commit comments

Comments
 (0)