Skip to content

Commit 811437e

Browse files
authored
[AMD] Support ExtractSliceOp for AxisInfo (#7094)
This commit updates AxisInfo to support backend callbacks to enable recognizing backend ops. One can use `ExtractSliceOp` to slice tensors of pointers to refine `tt.load` or `tt.store`. The `TritonAMDGPUConvertToBufferOpsBase` pass would previously fail to perform negativity analysis in the presence of `ExtractSliceOp`, which after rewrites slices tensors of offsets. This PR addresses the issue.
1 parent bb1000f commit 811437e

File tree

16 files changed

+230
-102
lines changed

16 files changed

+230
-102
lines changed

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace mlir {
3333
namespace test {
3434
void registerTestAliasPass();
3535
void registerTestAlignmentPass();
36+
void registerAMDTestAlignmentPass();
3637
void registerTestAllocationPass();
3738
void registerTestMembarPass();
3839
void registerTestAMDGPUMembarPass();
@@ -48,6 +49,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
4849
mlir::triton::nvidia_gpu::registerTritonNvidiaGPUPasses();
4950
mlir::test::registerTestAliasPass();
5051
mlir::test::registerTestAlignmentPass();
52+
mlir::test::registerAMDTestAlignmentPass();
5153
mlir::test::registerTestAllocationPass();
5254
mlir::test::registerTestMembarPass();
5355
mlir::test::registerTestLoopPeelingPass();

include/triton/Analysis/AxisInfo.h

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,49 @@ class AxisInfo {
149149
std::optional<int64_t> constantValue;
150150
};
151151

152+
class AxisInfoVisitor {
153+
public:
154+
AxisInfoVisitor() = default;
155+
virtual ~AxisInfoVisitor() = default;
156+
157+
bool isContiguousDim(const AxisInfo &info, ArrayRef<int64_t> shape, int dim) {
158+
return info.getContiguity(dim) == shape[dim];
159+
}
160+
161+
bool isConstantDim(const AxisInfo &info, ArrayRef<int64_t> shape, int dim) {
162+
return info.getConstancy(dim) == shape[dim];
163+
}
164+
165+
virtual AxisInfo
166+
getAxisInfo(Operation *op,
167+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
168+
169+
virtual bool match(Operation *op) = 0;
170+
};
171+
172+
class AxisInfoVisitorList {
173+
public:
174+
template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
175+
void append() {
176+
(visitors.emplace_back(std::make_unique<Ts>()), ...);
177+
}
178+
179+
AxisInfo apply(Operation *op,
180+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
181+
for (auto &visitor : visitors)
182+
if (visitor->match(op))
183+
return visitor->getAxisInfo(op, operands);
184+
return AxisInfo();
185+
}
186+
187+
private:
188+
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
189+
};
190+
191+
namespace axisinfo {
192+
using CallbackType = std::function<void(AxisInfoVisitorList &)>;
193+
} // namespace axisinfo
194+
152195
// Module level axis info analysis based on the call graph, assuming that we do
153196
// not have recursive functions.
154197
//
@@ -159,7 +202,8 @@ class AxisInfo {
159202
using AxisInfoMapT = DenseMap<Value, AxisInfo>;
160203
class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
161204
public:
162-
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
205+
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp,
206+
axisinfo::CallbackType callback = nullptr)
163207
: CallGraph<AxisInfoMapT>(moduleOp) {
164208
SmallVector<FunctionOpInterface> funcs;
165209
for (auto root : getRoots()) {
@@ -175,7 +219,7 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
175219
SetVector<FunctionOpInterface> sortedFuncs(funcs.begin(), funcs.end());
176220
SymbolTableCollection symbolTable;
177221
for (auto funcOp : llvm::reverse(sortedFuncs)) {
178-
initialize(funcOp);
222+
initialize(funcOp, callback);
179223
funcOp.walk([&](CallOpInterface callOp) {
180224
auto callee = dyn_cast<FunctionOpInterface>(
181225
callOp.resolveCallableInTable(&symbolTable));
@@ -217,10 +261,10 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
217261
unsigned getMaskAlignment(Value mask);
218262

219263
private:
220-
void initialize(FunctionOpInterface funcOp);
264+
void initialize(FunctionOpInterface funcOp,
265+
axisinfo::CallbackType callback = nullptr);
221266
void update(CallOpInterface callOp, FunctionOpInterface funcOp);
222267
};
223-
224268
} // namespace mlir::triton
225269

226270
#endif

lib/Analysis/AxisInfo.cpp

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1+
#include "triton/Analysis/AxisInfo.h"
12
#include "mlir/Analysis/DataFlowFramework.h"
23
#include "mlir/Dialect/UB/IR/UBOps.h"
3-
#include "llvm/Support/Debug.h"
4-
#include "llvm/Support/raw_ostream.h"
5-
6-
#include "triton/Analysis/AxisInfo.h"
74
#include "triton/Dialect/Triton/IR/Dialect.h"
85
#include "triton/Dialect/Triton/IR/Utility.h"
96
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
7+
#include "llvm/Support/Debug.h"
8+
#include "llvm/Support/raw_ostream.h"
109

1110
#define DEBUG_TYPE "axis-info"
1211
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -52,28 +51,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5251
return lhs * rhs;
5352
}
5453

55-
class AxisInfoVisitor {
56-
public:
57-
AxisInfoVisitor() = default;
58-
virtual ~AxisInfoVisitor() = default;
59-
60-
static bool isContiguousDim(const AxisInfo &info, ArrayRef<int64_t> shape,
61-
int dim) {
62-
return info.getContiguity(dim) == shape[dim];
63-
}
64-
65-
static bool isConstantDim(const AxisInfo &info, ArrayRef<int64_t> shape,
66-
int dim) {
67-
return info.getConstancy(dim) == shape[dim];
68-
}
69-
70-
virtual AxisInfo
71-
getAxisInfo(Operation *op,
72-
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
73-
74-
virtual bool match(Operation *op) = 0;
75-
};
76-
7754
// Base class for all operations
7855
template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor {
7956
public:
@@ -147,25 +124,6 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
147124
}
148125
};
149126

150-
class AxisInfoVisitorList {
151-
public:
152-
template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
153-
void append() {
154-
(visitors.emplace_back(std::make_unique<Ts>()), ...);
155-
}
156-
157-
AxisInfo apply(Operation *op,
158-
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
159-
for (auto &visitor : visitors)
160-
if (visitor->match(op))
161-
return visitor->getAxisInfo(op, operands);
162-
return AxisInfo();
163-
}
164-
165-
private:
166-
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
167-
};
168-
169127
class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
170128
dataflow::Lattice<AxisInfo>> {
171129
private:
@@ -193,7 +151,8 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
193151
}
194152

195153
public:
196-
AxisInfoAnalysis(DataFlowSolver &solver);
154+
AxisInfoAnalysis(DataFlowSolver &solver,
155+
axisinfo::CallbackType callback = nullptr);
197156
using dataflow::SparseForwardDataFlowAnalysis<
198157
dataflow::Lattice<AxisInfo>>::getLatticeElement;
199158
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;
@@ -1031,7 +990,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
1031990
// AxisInfoAnalysis
1032991
//===----------------------------------------------------------------------===//
1033992

1034-
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
993+
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
994+
axisinfo::CallbackType callback)
1035995
: dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(
1036996
solver) {
1037997
// UnrealizedConversionCast:
@@ -1070,6 +1030,9 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10701030
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
10711031
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
10721032
visitors.append<LoadOpAxisInfoVisitor>();
1033+
1034+
if (callback)
1035+
callback(visitors);
10731036
}
10741037

10751038
LogicalResult AxisInfoAnalysis::visitOperation(
@@ -1339,9 +1302,10 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
13391302
return alignment;
13401303
}
13411304

1342-
void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) {
1305+
void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
1306+
axisinfo::CallbackType callback) {
13431307
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
1344-
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
1308+
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>(callback);
13451309
// Walk pre-order so analysis results can be propagated into nested isolated
13461310
// regions.
13471311
WalkResult result =

test/Analysis/amd/test-alignment.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: triton-opt %s -test-print-amd-alignment -split-input-file -verify-diagnostics=only-expected -o /dev/null
2+
3+
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
4+
5+
tt.func public @kernel(%arg0: tensor<256x64xf16, #mma> {tt.contiguity=256 : i32, tt.divisibility=6: i32, tt.constancy=1: i32}) attributes {noinline = false} {
6+
// expected-remark @below {{contiguity = [128, 32], divisibility = [6, 6], constancy = [1, 1], constant_value = <none>}}
7+
%0 = amdgpu.extract_slice %arg0 [128, 32] : tensor<256x64xf16, #mma> to tensor<128x32xf16, #mma>
8+
tt.return
9+
}

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,3 +650,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
650650
tt.return
651651
}
652652
}
653+
654+
// -----
655+
656+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
657+
658+
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
659+
tt.func @extract_slice(%arg0: !tt.ptr<f32>) -> tensor<128x256xf32, #blocked> {
660+
%0 = arith.constant dense<0> : tensor<256x256xi64, #blocked>
661+
%1 = amdgpu.extract_slice %0 [0, 0] : tensor<256x256xi64, #blocked> to tensor<128x256xi64, #blocked>
662+
%2 = arith.trunci %1 : tensor<128x256xi64, #blocked> to tensor<128x256xi32, #blocked>
663+
%3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x256x!tt.ptr<f32>, #blocked>
664+
%4 = tt.addptr %3, %2 : tensor<128x256x!tt.ptr<f32>, #blocked>, tensor<128x256xi32, #blocked>
665+
%5 = tt.load %4 : tensor<128x256x!tt.ptr<f32>, #blocked>
666+
tt.return %5 : tensor<128x256xf32, #blocked>
667+
}
668+
}
669+
670+
// CHECK-LABEL: tt.func @extract_slice(
671+
// CHECK-SAME: %[[ARG_0:.*]]: !tt.ptr<f32>) -> tensor<128x256xf32, #blocked> {
672+
// CHECK: %[[VAR_0:.*]] = arith.constant dense<0> : tensor<256x256xi64, #blocked>
673+
// CHECK: %[[VAR_1:.*]] = amdgpu.extract_slice %[[VAR_0]] [0, 0] : tensor<256x256xi64, #blocked> to tensor<128x256xi64, #blocked>
674+
// CHECK: %[[VAR_2:.*]] = arith.trunci %[[VAR_1]] : tensor<128x256xi64, #blocked> to tensor<128x256xi32, #blocked>
675+
// CHECK: %[[VAR_3:.*]] = amdgpu.buffer_load %[[ARG_0]][%[[VAR_2]]] : tensor<128x256xf32, #blocked>
676+
// CHECK: tt.return %[[VAR_3]] : tensor<128x256xf32, #blocked>
677+
// CHECK: }

test/include/Analysis/TestAxisInfo.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#pragma once
2+
3+
#include "mlir/IR/Diagnostics.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "triton/Analysis/AxisInfo.h"
6+
7+
using namespace mlir;
8+
using namespace mlir::triton;
9+
10+
namespace mlir::test {
11+
12+
struct TestAxisInfoPass
13+
: public PassWrapper<TestAxisInfoPass, OperationPass<ModuleOp>> {
14+
15+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
16+
17+
StringRef getArgument() const override { return "test-print-alignment"; }
18+
StringRef getDescription() const final {
19+
return "print the result of the alignment analysis pass";
20+
}
21+
22+
void runOnOperation() override {
23+
Operation *operation = this->getOperation();
24+
ModuleOp moduleOp = cast<ModuleOp>(operation);
25+
auto moduleAxisInfoAnalysis = getAnalysis(moduleOp);
26+
moduleOp.walk([&](FuncOp funcOp) {
27+
funcOp.walk([&](Operation *op) {
28+
if (op->getNumResults() < 1)
29+
return;
30+
for (Value result : op->getResults()) {
31+
InFlightDiagnostic diag = mlir::emitRemark(op->getLoc());
32+
diag << result;
33+
diag << " => ";
34+
auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result);
35+
if (axisInfo) {
36+
std::string str;
37+
llvm::raw_string_ostream os(str);
38+
axisInfo->print(os);
39+
diag << str;
40+
}
41+
}
42+
});
43+
});
44+
}
45+
46+
protected:
47+
virtual ModuleAxisInfoAnalysis getAnalysis(ModuleOp moduleOp) const {
48+
return ModuleAxisInfoAnalysis(moduleOp);
49+
}
50+
};
51+
52+
} // namespace mlir::test

test/lib/Analysis/TestAxisInfo.cpp

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,4 @@
1-
#include "mlir/IR/Diagnostics.h"
2-
#include "mlir/Pass/Pass.h"
3-
#include "triton/Analysis/AxisInfo.h"
4-
5-
using namespace mlir;
6-
using namespace mlir::triton;
7-
8-
namespace {
9-
10-
struct TestAxisInfoPass
11-
: public PassWrapper<TestAxisInfoPass, OperationPass<ModuleOp>> {
12-
13-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
14-
15-
StringRef getArgument() const final { return "test-print-alignment"; }
16-
StringRef getDescription() const final {
17-
return "print the result of the alignment analysis pass";
18-
}
19-
20-
void runOnOperation() override {
21-
Operation *operation = getOperation();
22-
ModuleOp moduleOp = cast<ModuleOp>(operation);
23-
ModuleAxisInfoAnalysis moduleAxisInfoAnalysis(moduleOp);
24-
moduleOp.walk([&](FuncOp funcOp) {
25-
funcOp.walk([&](Operation *op) {
26-
if (op->getNumResults() < 1)
27-
return;
28-
for (Value result : op->getResults()) {
29-
InFlightDiagnostic diag = mlir::emitRemark(op->getLoc());
30-
diag << result;
31-
diag << " => ";
32-
auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result);
33-
if (axisInfo) {
34-
std::string str;
35-
llvm::raw_string_ostream os(str);
36-
axisInfo->print(os);
37-
diag << str;
38-
}
39-
}
40-
});
41-
});
42-
}
43-
};
44-
45-
} // namespace
1+
#include "test/include/Analysis/TestAxisInfo.h"
462

473
namespace mlir {
484
namespace test {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef TRITONAMD_ANALYSIS_AXIS_INFO_EXT_H
2+
#define TRITONAMD_ANALYSIS_AXIS_INFO_EXT_H
3+
4+
#include "include/triton/Analysis/AxisInfo.h"
5+
6+
namespace mlir::triton::AMD {
7+
8+
struct AxisInfoExt {
9+
static void addVisitors(mlir::triton::AxisInfoVisitorList &visitors);
10+
};
11+
12+
class ModuleAxisInfoAnalysis : public mlir::triton::ModuleAxisInfoAnalysis {
13+
public:
14+
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
15+
: mlir::triton::ModuleAxisInfoAnalysis(moduleOp,
16+
AxisInfoExt::addVisitors) {}
17+
};
18+
} // namespace mlir::triton::AMD
19+
20+
#endif

0 commit comments

Comments
 (0)