Skip to content

Commit 6043bce

Browse files
committed
Test working
1 parent 4b517b9 commit 6043bce

File tree

2 files changed

+76
-108
lines changed

2 files changed

+76
-108
lines changed

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -253,7 +253,7 @@ def make_ttgir(mod, metadata, opt, properties):
253253
passes.ttgpuir.add_optimize_dot_operands(pm, True)
254254
intel.passes.ttgpuir.add_optimize_reduction_locality(pm)
255255
intel.passes.ttgpuir.add_optimize_elementwise_parallelism(pm)
256-
#intel.passes.ttgpuir.add_remove_layout_conversions(pm)
256+
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
257257
intel.passes.ttgpuir.add_reduce_data_duplication(pm)
258258
passes.ttgpuir.add_reorder_instructions(pm)
259259
passes.common.add_cse(pm)

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 75 additions & 107 deletions
Original file line number | Diff line number | Diff line change
@@ -152,10 +152,9 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
152152
static constexpr int preferredReductionAxis = 1;
153153

154154
// Intermediate reductions
155-
static constexpr int finalEWReductionAxis = 0;
156-
static constexpr int finalWarpsReductionAxis = 2;
155+
static constexpr int finalElementwiseReductionAxis = 0;
156+
static constexpr int finalWarpsReductionAxis = 1;
157157
static constexpr int repCountReshapedAxis = 2;
158-
static constexpr int withinWarpXAxisReshapedAxis = 5;
159158

160159
LogicalResult matchAndRewrite(ReduceOp op,
161160
PatternRewriter &rewriter) const final {
@@ -188,6 +187,18 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
188187
0)
189188
return failure();
190189

190+
// The layout should cover the tensor shape.
191+
ArrayRef<int64_t> shape = type.getShape();
192+
if ( // X axis condition
193+
encoding.getExecutionSize() * encoding.getRepCluster()[1] *
194+
encoding.getWarpsPerCTA()[1] !=
195+
shape[1] ||
196+
// Y axis condition
197+
encoding.getRepeatCount() * encoding.getRepCluster()[0] *
198+
encoding.getWarpsPerCTA()[0] !=
199+
shape[0])
200+
return failure();
201+
191202
LLVM_DEBUG(llvm::dbgs() << "Optimizing reduction: " << op << "\n");
192203

193204
operand = reshapeForElementWiseReduction(op, rewriter, encoding);
@@ -201,18 +212,12 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
201212
<< "Performed elementwise reduction within repCount: " << operand
202213
<< "\n");
203214

204-
operand = performElementWiseReductionAcrossRepCounts(op, rewriter, operand);
205-
206-
LLVM_DEBUG(llvm::dbgs()
207-
<< "Performed elementwise reduction across repCount: " << operand
208-
<< "\n");
209-
210215
operand = reshapeForFinalReduction(op, rewriter, operand, encoding);
211216

212217
LLVM_DEBUG(llvm::dbgs()
213218
<< "Reshaped for final reduction: " << operand << "\n");
214219

215-
operand = convertLayoutForFinalReduction(op, rewriter, operand);
220+
operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding);
216221

217222
LLVM_DEBUG(llvm::dbgs()
218223
<< "Converted layout for final reduction: " << operand << "\n");
@@ -227,15 +232,15 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
227232
LLVM_DEBUG(llvm::dbgs() << "Final across-warps reduction performed: "
228233
<< operand << "\n");
229234

230-
operand = convertLayoutToOriginalType(op, rewriter, operand, encoding);
235+
operand = reshapeToOriginalType(op, rewriter, operand, encoding);
231236

232237
LLVM_DEBUG(llvm::dbgs()
233-
<< "Converted layout to original type: " << operand << "\n");
238+
<< "Reshaped to original type: " << operand << "\n");
234239

235-
operand = reshapeToOriginalType(op, rewriter, operand);
240+
operand = convertLayoutToOriginalType(op, rewriter, operand);
236241

237242
LLVM_DEBUG(llvm::dbgs()
238-
<< "Reshaped to original type: " << operand << "\n");
243+
<< "Converted layout to original type: " << operand << "\n");
239244

240245
rewriter.replaceOp(op, operand);
241246

@@ -251,44 +256,26 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
251256
auto oldType = cast<RankedTensorType>(val.getType());
252257
ArrayRef<int64_t> oldShape = oldType.getShape();
253258

254-
constexpr size_t rank = 8;
259+
constexpr size_t rank = 6;
255260
std::array<int64_t, rank> shape{
256-
dpasEncoding.getExecutionSize(),
257-
dpasEncoding.getRepeatCount(),
258-
dpasEncoding.getRepCluster()[1],
259-
dpasEncoding.getRepCluster()[0],
260-
dpasEncoding.getWarpsPerCTA()[1],
261-
dpasEncoding.getWarpsPerCTA()[0],
262-
oldShape[1] /
263-
(dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
264-
dpasEncoding.getWarpsPerCTA()[1]),
265-
oldShape[0] /
266-
(dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] *
267-
dpasEncoding.getWarpsPerCTA()[0])};
268-
std::array<unsigned, rank> sizePerThread{
269-
1,
270-
dpasEncoding.getRepeatCount(),
271-
dpasEncoding.getRepCluster()[1],
272-
dpasEncoding.getRepCluster()[0],
273-
1,
274-
1,
275-
static_cast<unsigned>(oldShape[1]) /
276-
(dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
277-
dpasEncoding.getWarpsPerCTA()[1]),
278-
static_cast<unsigned>(oldShape[0]) /
279-
(dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] *
280-
dpasEncoding.getWarpsPerCTA()[0])};
261+
dpasEncoding.getExecutionSize(), dpasEncoding.getRepeatCount(),
262+
dpasEncoding.getRepCluster()[1], dpasEncoding.getRepCluster()[0],
263+
dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]};
264+
std::array<unsigned, rank> sizePerThread{1,
265+
dpasEncoding.getRepeatCount(),
266+
dpasEncoding.getRepCluster()[1],
267+
dpasEncoding.getRepCluster()[0],
268+
1,
269+
1};
281270
std::array<unsigned, rank> threadsPerWarp{
282-
dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1, 1};
271+
dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1};
283272
std::array<unsigned, rank> warpsPerCTA{1,
284273
1,
285274
1,
286275
1,
287276
dpasEncoding.getWarpsPerCTA()[1],
288-
dpasEncoding.getWarpsPerCTA()[0],
289-
1,
290-
1};
291-
std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5, 6, 7};
277+
dpasEncoding.getWarpsPerCTA()[0]};
278+
constexpr std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5};
292279
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
293280

294281
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
@@ -324,24 +311,21 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
324311
return performReduction(op, rewriter, val, /*axis=*/repCountReshapedAxis);
325312
}
326313

327-
Value performElementWiseReductionAcrossRepCounts(ReduceOp op,
328-
PatternRewriter &rewriter,
329-
Value val) const {
330-
return performReduction(op, rewriter, val,
331-
/*axis=*/withinWarpXAxisReshapedAxis);
332-
}
333-
334314
Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
335-
Value val) const {
315+
Value val,
316+
DpasEncodingAttr dpasEncoding) const {
336317
auto oldType = cast<RankedTensorType>(val.getType());
337318
auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
338319
RankedTensorType::Builder type(oldType);
339320

340-
SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
341-
SmallVector<unsigned> threadsPerWarp = oldEncoding.getThreadsPerWarp();
342-
343-
std::swap(sizePerThread[0], sizePerThread[1]);
344-
std::swap(threadsPerWarp[0], threadsPerWarp[1]);
321+
constexpr size_t rank = 4;
322+
std::array<unsigned, rank> sizePerThread{
323+
dpasEncoding.getExecutionSize(),
324+
dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
325+
dpasEncoding.getExecutionSize(),
326+
1, 1};
327+
std::array<unsigned, rank> threadsPerWarp{
328+
1, dpasEncoding.getExecutionSize(), 1, 1};
345329

346330
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
347331
sizePerThread, threadsPerWarp, oldEncoding.getWarpsPerCTA(),
@@ -359,32 +343,20 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
359343
auto oldType = cast<RankedTensorType>(val.getType());
360344
ArrayRef<int64_t> oldShape = oldType.getShape();
361345

362-
constexpr size_t rank = 6;
363-
std::array<int64_t, rank> shape{dpasEncoding.getExecutionSize(),
364-
dpasEncoding.getExecutionSize(),
365-
dpasEncoding.getRepeatCount() *
366-
dpasEncoding.getRepCluster()[0] /
367-
dpasEncoding.getExecutionSize(),
368-
dpasEncoding.getWarpsPerCTA()[1],
369-
dpasEncoding.getWarpsPerCTA()[0],
370-
oldShape.back()};
371-
std::array<unsigned, rank> sizePerThread{
372-
1,
346+
constexpr size_t rank = 4;
347+
std::array<int64_t, rank> shape{
373348
dpasEncoding.getExecutionSize(),
374-
dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
375-
dpasEncoding.getExecutionSize(),
376-
1,
377-
1,
378-
static_cast<unsigned>(oldShape.back())};
379-
std::array<unsigned, rank> threadsPerWarp{
380-
dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1};
381-
std::array<unsigned, rank> warpsPerCTA{1,
382-
1,
383-
1,
349+
dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0],
350+
dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]};
351+
std::array<unsigned, rank> sizePerThread{
352+
1, dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0], 1,
353+
1};
354+
std::array<unsigned, rank> threadsPerWarp{dpasEncoding.getExecutionSize(),
355+
1, 1, 1};
356+
std::array<unsigned, rank> warpsPerCTA{1, 1,
384357
dpasEncoding.getWarpsPerCTA()[1],
385-
dpasEncoding.getWarpsPerCTA()[0],
386-
1};
387-
std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5};
358+
dpasEncoding.getWarpsPerCTA()[0]};
359+
constexpr std::array<unsigned, rank> order{0, 1, 2, 3};
388360
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
389361

390362
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
@@ -404,7 +376,8 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
404376

405377
Value performFinalElementwiseReduction(ReduceOp op, PatternRewriter &rewriter,
406378
Value val) const {
407-
return performReduction(op, rewriter, val, /*axis=*/finalEWReductionAxis);
379+
return performReduction(op, rewriter, val,
380+
/*axis=*/finalElementwiseReductionAxis);
408381
}
409382

410383
Value performFinalAcrossWarpsReduction(ReduceOp op, PatternRewriter &rewriter,
@@ -413,43 +386,38 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
413386
/*axis=*/finalWarpsReductionAxis);
414387
}
415388

416-
Value convertLayoutToOriginalType(ReduceOp op, PatternRewriter &rewriter,
417-
Value val,
418-
DpasEncodingAttr dpasEncoding) const {
419-
auto oldType = cast<RankedTensorType>(val.getType());
420-
ArrayRef<int64_t> oldShape = oldType.getShape();
421-
RankedTensorType::Builder type(oldType);
389+
Value reshapeToOriginalType(ReduceOp op, PatternRewriter &rewriter, Value val,
390+
DpasEncodingAttr dpasEncoding) const {
391+
RankedTensorType::Builder type(
392+
cast<RankedTensorType>(op.getResult().front().getType()));
422393

423-
constexpr size_t rank = 5;
394+
constexpr size_t rank = 2;
424395
std::array<unsigned, rank> sizePerThread{
425-
dpasEncoding.getExecutionSize(),
426-
dpasEncoding.getRepCluster()[0] * dpasEncoding.getRepeatCount() /
427-
dpasEncoding.getExecutionSize(),
428-
1, 1, 1};
429-
std::array<unsigned, rank> threadsPerWarp{1, 1, 1, 1,
396+
1, dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
397+
dpasEncoding.getExecutionSize()};
398+
std::array<unsigned, rank> threadsPerWarp{1,
430399
dpasEncoding.getExecutionSize()};
431-
std::array<unsigned, rank> warpsPerCTA{1, 1,
432-
dpasEncoding.getWarpsPerCTA()[0], 1,
433-
dpasEncoding.getWarpsPerCTA()[1]};
434-
std::array<unsigned, rank> order{0, 1, 2, 3, 4};
400+
std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[1],
401+
dpasEncoding.getWarpsPerCTA()[0]};
402+
constexpr std::array<unsigned, rank> order{0, 1};
435403
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
436404

437405
auto parentEncoding = rewriter.getAttr<BlockedEncodingAttr>(
438406
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
439407

440-
type.setEncoding(parentEncoding.squeeze(rank - 1));
441-
442-
return rewriter.create<ConvertLayoutOp>(
443-
op.getLoc(), static_cast<RankedTensorType>(type), val);
444-
}
408+
type.setEncoding(parentEncoding.squeeze(0));
445409

446-
Value reshapeToOriginalType(ReduceOp op, PatternRewriter &rewriter,
447-
Value val) const {
448410
return rewriter.create<ReshapeOp>(op.getLoc(),
449-
op.getResult().front().getType(), val,
411+
static_cast<RankedTensorType>(type), val,
450412
/*allow_reorder=*/true,
451413
/*efficient_layout=*/true);
452414
}
415+
416+
Value convertLayoutToOriginalType(ReduceOp op, PatternRewriter &rewriter,
417+
Value val) const {
418+
return rewriter.create<ConvertLayoutOp>(
419+
op.getLoc(), op.getResult().front().getType(), val);
420+
}
453421
};
454422

455423
struct TritonIntelGPUOptimizeReductionLocality final

0 commit comments

Comments
 (0)