Skip to content

Commit 5f06efc

Browse files
committed
Fix red
1 parent: 6043bce; commit: 5f06efc

File tree

1 file changed: +20 -14 lines changed

1 file changed: +20 -14 lines changed

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 20 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -212,15 +212,15 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
212212
<< "Performed elementwise reduction within repCount: " << operand
213213
<< "\n");
214214

215-
operand = reshapeForFinalReduction(op, rewriter, operand, encoding);
215+
operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding);
216216

217217
LLVM_DEBUG(llvm::dbgs()
218-
<< "Reshaped for final reduction: " << operand << "\n");
218+
<< "Converted layout for final reduction: " << operand << "\n");
219219

220-
operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding);
220+
operand = reshapeForFinalReduction(op, rewriter, operand, encoding);
221221

222222
LLVM_DEBUG(llvm::dbgs()
223-
<< "Converted layout for final reduction: " << operand << "\n");
223+
<< "Reshaped for final reduction: " << operand << "\n");
224224

225225
operand = performFinalElementwiseReduction(op, rewriter, operand);
226226

@@ -315,21 +315,25 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
315315
Value val,
316316
DpasEncodingAttr dpasEncoding) const {
317317
auto oldType = cast<RankedTensorType>(val.getType());
318-
auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
319318
RankedTensorType::Builder type(oldType);
320319

321-
constexpr size_t rank = 4;
320+
constexpr size_t rank = 5;
322321
std::array<unsigned, rank> sizePerThread{
323322
dpasEncoding.getExecutionSize(),
324323
dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
325324
dpasEncoding.getExecutionSize(),
326-
1, 1};
325+
1, 1, 1};
327326
std::array<unsigned, rank> threadsPerWarp{
328-
1, dpasEncoding.getExecutionSize(), 1, 1};
327+
1, dpasEncoding.getExecutionSize() / dpasEncoding.getRepCluster()[0],
328+
dpasEncoding.getRepCluster()[0], 1, 1};
329+
std::array<unsigned, rank> warpsPerCTA{1, 1, 1,
330+
dpasEncoding.getWarpsPerCTA()[1],
331+
dpasEncoding.getWarpsPerCTA()[0]};
332+
constexpr std::array<unsigned, rank> order{0, 1, 2, 3, 4};
333+
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
329334

330335
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
331-
sizePerThread, threadsPerWarp, oldEncoding.getWarpsPerCTA(),
332-
oldEncoding.getOrder(), oldEncoding.getCTALayout());
336+
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
333337

334338
type.setEncoding(encoding);
335339

@@ -349,10 +353,12 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
349353
dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0],
350354
dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]};
351355
std::array<unsigned, rank> sizePerThread{
352-
1, dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0], 1,
353-
1};
354-
std::array<unsigned, rank> threadsPerWarp{dpasEncoding.getExecutionSize(),
355-
1, 1, 1};
356+
dpasEncoding.getExecutionSize(),
357+
dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
358+
dpasEncoding.getExecutionSize(),
359+
1, 1};
360+
std::array<unsigned, rank> threadsPerWarp{
361+
1, dpasEncoding.getExecutionSize(), 1, 1};
356362
std::array<unsigned, rank> warpsPerCTA{1, 1,
357363
dpasEncoding.getWarpsPerCTA()[1],
358364
dpasEncoding.getWarpsPerCTA()[0]};

0 commit comments

Comments (0)