Skip to content

Commit 966d6f1

Browse files
committed
Cover Y axis
1 parent 5f06efc commit 966d6f1

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

third_party/intel/lib/Analysis/Utility.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) {
196196
if (!conversion)
197197
return false;
198198

199+
llvm::errs() << conversion << "\n";
200+
199201
// Expected conversion is:
200202
// - register=1 -> (0, 1)
201203
// ...

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
154154
// Intermediate reductions
155155
static constexpr int finalElementwiseReductionAxis = 0;
156156
static constexpr int finalWarpsReductionAxis = 1;
157-
static constexpr int repCountReshapedAxis = 2;
157+
static constexpr int innerElementwiseReductionAxis = 2;
158+
static constexpr int outerElementwiseReductionAxis = 4;
158159

159160
LogicalResult matchAndRewrite(ReduceOp op,
160161
PatternRewriter &rewriter) const final {
@@ -187,16 +188,10 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
187188
0)
188189
return failure();
189190

190-
// The layout should cover the tensor shape.
191-
ArrayRef<int64_t> shape = type.getShape();
192-
if ( // X axis condition
193-
encoding.getExecutionSize() * encoding.getRepCluster()[1] *
194-
encoding.getWarpsPerCTA()[1] !=
195-
shape[1] ||
196-
// Y axis conditions
197-
encoding.getRepeatCount() * encoding.getRepCluster()[0] *
198-
encoding.getWarpsPerCTA()[0] !=
199-
shape[0])
191+
// The encoding should cover the Y axis.
192+
if (encoding.getRepeatCount() * encoding.getRepCluster()[0] *
193+
encoding.getWarpsPerCTA()[0] !=
194+
type.getShape()[0])
200195
return failure();
201196

202197
LLVM_DEBUG(llvm::dbgs() << "Optimizing reduction: " << op << "\n");
@@ -206,11 +201,10 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
206201
LLVM_DEBUG(llvm::dbgs()
207202
<< "Reshaped for elementwise reduction: " << operand << "\n");
208203

209-
operand = performElementWiseReductionWithinRepCount(op, rewriter, operand);
204+
operand = performInitialElementWiseReductions(op, rewriter, operand);
210205

211-
LLVM_DEBUG(llvm::dbgs()
212-
<< "Performed elementwise reduction within repCount: " << operand
213-
<< "\n");
206+
LLVM_DEBUG(llvm::dbgs() << "Performed initial elementwise reductions: "
207+
<< operand << "\n");
214208

215209
operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding);
216210

@@ -256,26 +250,37 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
256250
auto oldType = cast<RankedTensorType>(val.getType());
257251
ArrayRef<int64_t> oldShape = oldType.getShape();
258252

259-
constexpr size_t rank = 6;
253+
constexpr size_t rank = 7;
260254
std::array<int64_t, rank> shape{
261-
dpasEncoding.getExecutionSize(), dpasEncoding.getRepeatCount(),
262-
dpasEncoding.getRepCluster()[1], dpasEncoding.getRepCluster()[0],
263-
dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]};
264-
std::array<unsigned, rank> sizePerThread{1,
265-
dpasEncoding.getRepeatCount(),
266-
dpasEncoding.getRepCluster()[1],
267-
dpasEncoding.getRepCluster()[0],
268-
1,
269-
1};
255+
dpasEncoding.getExecutionSize(),
256+
dpasEncoding.getRepeatCount(),
257+
dpasEncoding.getRepCluster()[1],
258+
dpasEncoding.getRepCluster()[0],
259+
dpasEncoding.getWarpsPerCTA()[1],
260+
oldShape[1] /
261+
(dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
262+
dpasEncoding.getWarpsPerCTA()[1]),
263+
dpasEncoding.getWarpsPerCTA()[0]};
264+
std::array<unsigned, rank> sizePerThread{
265+
1,
266+
dpasEncoding.getRepeatCount(),
267+
dpasEncoding.getRepCluster()[1],
268+
dpasEncoding.getRepCluster()[0],
269+
1,
270+
static_cast<unsigned>(oldShape[1]) /
271+
(dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
272+
dpasEncoding.getWarpsPerCTA()[1]),
273+
1};
270274
std::array<unsigned, rank> threadsPerWarp{
271-
dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1};
275+
dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1};
272276
std::array<unsigned, rank> warpsPerCTA{1,
273277
1,
274278
1,
275279
1,
276280
dpasEncoding.getWarpsPerCTA()[1],
281+
1,
277282
dpasEncoding.getWarpsPerCTA()[0]};
278-
constexpr std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5};
283+
constexpr std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5, 6};
279284
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
280285

281286
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
@@ -305,10 +310,14 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
305310
return newOp.getResult().front();
306311
}
307312

308-
Value performElementWiseReductionWithinRepCount(ReduceOp op,
309-
PatternRewriter &rewriter,
310-
Value val) const {
311-
return performReduction(op, rewriter, val, /*axis=*/repCountReshapedAxis);
313+
Value performInitialElementWiseReductions(ReduceOp op,
314+
PatternRewriter &rewriter,
315+
Value val) const {
316+
return performReduction(
317+
op, rewriter,
318+
performReduction(op, rewriter, val,
319+
/*axis=*/innerElementwiseReductionAxis),
320+
outerElementwiseReductionAxis);
312321
}
313322

314323
Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter,

0 commit comments

Comments
 (0)