Skip to content

Commit d662e65

Browse files
authored
[CommonCodeClean]Clean changes in common code (#2950)
Clean changes in common code
1 parent fdab3bb commit d662e65

File tree

2 files changed

+21
-48
lines changed

2 files changed

+21
-48
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -304,16 +304,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
304304
}
305305
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
306306
auto rank = dotLayout.getWarpsPerCTA().size();
307-
// FIXME: delete if branch for `DpasEncodingAttr` and provide more
308-
// general solution to make `getOrderForDotOperand` function compatible
309-
// with Intel layouts.
310-
// More details:
311-
// https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
312-
if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
313-
SmallVector<unsigned> order(rank);
314-
std::iota(order.rbegin(), order.rend(), 0);
315-
return order;
316-
}
317307
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
318308
}
319309
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -1093,10 +1083,6 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
10931083
return amdWmmaParent.getTotalElemsPerThreadForOperand(
10941084
shape, eltTy, getKWidth(), getOpIdx());
10951085
}
1096-
if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
1097-
return dpasParent.getTotalElemsPerThreadForOperand(
1098-
shape, eltTy, getKWidth(), getOpIdx());
1099-
}
11001086
}
11011087
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
11021088
auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -1159,17 +1145,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
11591145
return {};
11601146
}
11611147
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
1162-
// FIXME: delete if branch for `DpasEncodingAttr` and provide more
1163-
// general solution to make `getOrderForDotOperand` function compatible
1164-
// with Intel layouts.
1165-
// More details:
1166-
// https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
1167-
if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
1168-
return ::getOrder(*this);
1169-
} else {
1170-
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
1171-
/*kMajor*/ true);
1172-
}
1148+
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
1149+
/*kMajor*/ true);
11731150
}
11741151

11751152
LogicalResult DotOperandEncodingAttr::verify(

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,21 @@ struct LoadOpConversion
526526
};
527527
auto opIdx = getOpIdx();
528528

529+
std::optional<LinearLayout> llEncoding =
530+
cast<DistributedEncodingTrait>(encoding).toLinearLayout(
531+
tensorType.getShape());
532+
assert(llEncoding.has_value() && "invalid dot layout to linear layout");
533+
LinearEncodingAttr llAttr =
534+
LinearEncodingAttr::get(rewriter.getContext(), *llEncoding);
535+
SmallVector<unsigned> threadOrder = llAttr.getThreadOrder();
536+
size_t rank = threadOrder.size();
537+
const bool valueRowMajor =
538+
(threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0);
539+
assert((valueRowMajor ||
540+
(threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) &&
541+
"Only row_major or column_major is allowed");
542+
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
543+
529544
Type eltTy = tensorType.getElementType();
530545
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
531546

@@ -539,15 +554,15 @@ struct LoadOpConversion
539554
SmallVector<int64_t> numReps =
540555
dpasLayout.getDPASRepetitions(tensorShape, opIdx);
541556
const SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
542-
SmallVector<unsigned> dpasOrder = triton::gpu::getOrder(dpasLayout);
557+
SmallVector<unsigned> dpasWarpsOrder = triton::gpu::getOrder(dpasLayout);
543558
int threadsPerWarp = triton::gpu::getWarpSize(dpasLayout);
544559

545560
Value warpId = rewriter.create<arith::IndexCastOp>(
546561
loc, i32_ty,
547562
rewriter.create<mlir::gpu::SubgroupIdOp>(loc, /*upperBound=*/nullptr));
548563

549564
SmallVector<Value> multiDimWarpId =
550-
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
565+
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);
551566

552567
if (hasDpasLayout) {
553568
// A block load with the DPAS layout but without the DotDpasLayout is
@@ -557,14 +572,6 @@ struct LoadOpConversion
557572
// aligns to the DPAS layout as the DPAS operation output layout
558573
// distributes rows across work items.
559574

560-
size_t rank = dpasOrder.size();
561-
const bool valueRowMajor =
562-
(dpasOrder[rank - 2] == 1 && dpasOrder[rank - 1] == 0);
563-
assert((valueRowMajor ||
564-
(dpasOrder[rank - 2] == 0 && dpasOrder[rank - 1] == 1)) &&
565-
"Only row_major or column_major is allowed");
566-
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
567-
568575
if (isTransposeRequired) {
569576
// TODO: this would likely require a shuffle to match the expected
570577
// ordering coming out of the DPAS layout and requires more
@@ -675,17 +682,6 @@ struct LoadOpConversion
675682
return success();
676683
}
677684

678-
DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
679-
auto dotOrder = dotLayout.getThreadOrder();
680-
681-
size_t rank = dotOrder.size();
682-
const bool valueRowMajor =
683-
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
684-
assert((valueRowMajor ||
685-
(dotOrder[rank - 2] == 0 && dotOrder[rank - 1] == 1)) &&
686-
"Only row_major or column_major is allowed");
687-
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
688-
689685
bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA);
690686
SmallVector<unsigned> dpasInstShape = isOperandA
691687
? dpasLayout.getDPASInstShapeA()
@@ -749,8 +745,8 @@ struct LoadOpConversion
749745
offsetBaseY] =
750746
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
751747

752-
unsigned tileWidth = elemsPerDPASInst[dotOrder[rank - 2]];
753-
unsigned tileHeight = elemsPerDPASInst[dotOrder[rank - 1]];
748+
unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2]];
749+
unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];
754750
unsigned vBlocks = 1;
755751
unsigned numOperandsOuterDimPerLoad = 1;
756752
unsigned numOperandsInnerDimPerLoad = 1;

0 commit comments

Comments (0)