Skip to content

Commit 3c154c8

Browse files
authored
[RemoveLayoutConversions]: Sync with upstream (#4895)
Replaces manual `RankedTensorType::get()` calls with the more concise `cloneWithEncoding()` method, and syncs with upstream. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 4199e4e commit 3c154c8

File tree

1 file changed

+54
-43
lines changed

1 file changed

+54
-43
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 54 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,7 @@ class LayoutRematerialization {
176176
SetVector<Operation *> opToDelete;
177177
FuncOp funcOp;
178178
DominanceInfo domInfo;
179+
PostDominanceInfo postDomInfo;
179180
};
180181

181182
void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
@@ -525,8 +526,7 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
525526
return rewrittenValue;
526527
OpBuilder rewriter(value.getContext());
527528
rewriter.setInsertionPointAfterValue(rewrittenValue);
528-
auto tmpType = RankedTensorType::get(tensorType.getShape(),
529-
tensorType.getElementType(), encoding);
529+
auto tmpType = tensorType.cloneWithEncoding(encoding);
530530
Value converted = rewriter.create<ConvertLayoutOp>(value.getLoc(), tmpType,
531531
rewrittenValue);
532532
// TODO: we could cache the conversion.
@@ -567,8 +567,7 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
567567
auto origType = dyn_cast<RankedTensorType>(op->getResult(i).getType());
568568
if (!origType)
569569
continue;
570-
auto newType = RankedTensorType::get(origType.getShape(),
571-
origType.getElementType(), encoding);
570+
auto newType = origType.cloneWithEncoding(encoding);
572571
newOp->getResult(i).setType(newType);
573572
}
574573
return newOp;
@@ -631,9 +630,7 @@ Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) {
631630
continue;
632631
}
633632
auto origType = dyn_cast<RankedTensorType>(ret.getType());
634-
auto newType =
635-
RankedTensorType::get(origType.getShape(), origType.getElementType(),
636-
it->second.encodings[0]);
633+
auto newType = origType.cloneWithEncoding(it->second.encodings[0]);
637634
returnTypes.push_back(newType);
638635
}
639636

@@ -683,8 +680,7 @@ Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) {
683680
continue;
684681
auto origType = cast<RankedTensorType>(ifOp->getResult(i).getType());
685682
Attribute encoding = *(it->second.encodings.begin());
686-
newResultTypes[i] = RankedTensorType::get(
687-
origType.getShape(), origType.getElementType(), encoding);
683+
newResultTypes[i] = origType.cloneWithEncoding(encoding);
688684
}
689685
auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), newResultTypes,
690686
ifOp.getCondition(), true, true);
@@ -940,17 +936,15 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
940936
srcEncoding = *(it->second.encodings.begin());
941937
Value src = getValueAs(convertOp.getSrc(), srcEncoding);
942938
auto tensorType = cast<RankedTensorType>(op->getResult(0).getType());
943-
auto newType = RankedTensorType::get(tensorType.getShape(),
944-
tensorType.getElementType(), encoding);
939+
auto newType = tensorType.cloneWithEncoding(encoding);
945940
auto cvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), newType, src);
946941
map(op->getResult(0), cvt.getResult());
947942
return cvt.getOperation();
948943
}
949944
if (canFoldIntoConversion(op, encoding)) {
950945
Operation *newOp = rewriter.clone(*op);
951946
auto tensorType = cast<RankedTensorType>(op->getResult(0).getType());
952-
auto newType = RankedTensorType::get(tensorType.getShape(),
953-
tensorType.getElementType(), encoding);
947+
auto newType = tensorType.cloneWithEncoding(encoding);
954948
auto cvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), newType,
955949
newOp->getResult(0));
956950
map(op->getResult(0), cvt.getResult());
@@ -1111,14 +1105,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11111105

11121106
Type resType = res.getType();
11131107
if (auto oldType = dyn_cast<RankedTensorType>(resType)) {
1114-
auto newType = RankedTensorType::get(
1115-
oldType.getShape(), oldType.getElementType(), it->second);
1108+
Type newType = oldType.cloneWithEncoding(it->second);
11161109
newTypes.push_back(newType);
11171110
} else if (auto ptrType = dyn_cast<PointerType>(resType)) {
11181111
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
1119-
auto newType = triton::PointerType::get(
1120-
RankedTensorType::get(tensorType.getShape(),
1121-
tensorType.getElementType(), it->second),
1112+
Type newType = triton::PointerType::get(
1113+
tensorType.cloneWithEncoding(it->second),
11221114
ptrType.getAddressSpace());
11231115
newTypes.push_back(newType);
11241116
}
@@ -1158,9 +1150,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11581150
if (isa<arith::ConstantOp>(op)) {
11591151
Operation *newOp = builder.clone(*op);
11601152
auto tensorType = cast<RankedTensorType>(op->getResult(0).getType());
1161-
auto newType = RankedTensorType::get(tensorType.getShape(),
1162-
tensorType.getElementType(),
1163-
layout[op->getResult(0)]);
1153+
auto newType = tensorType.cloneWithEncoding(layout[op->getResult(0)]);
11641154
auto cvt = builder.create<ConvertLayoutOp>(op->getLoc(), newType,
11651155
newOp->getResult(0));
11661156
mapping.map(op->getResult(0), cvt.getResult());
@@ -1178,14 +1168,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11781168
if (isTensorPointerType(oldType)) {
11791169
auto ptrType = cast<PointerType>(oldType);
11801170
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
1181-
newType = triton::PointerType::get(
1182-
RankedTensorType::get(tensorType.getShape(),
1183-
tensorType.getElementType(), it->second),
1184-
ptrType.getAddressSpace());
1171+
newType =
1172+
triton::PointerType::get(tensorType.cloneWithEncoding(it->second),
1173+
ptrType.getAddressSpace());
11851174
} else {
1186-
newType = RankedTensorType::get(
1187-
cast<RankedTensorType>(old.getType()).getShape(),
1188-
cast<RankedTensorType>(old.getType()).getElementType(), it->second);
1175+
newType =
1176+
cast<RankedTensorType>(old.getType()).cloneWithEncoding(it->second);
11891177
}
11901178
newV.setType(newType);
11911179
addRematValue(old, it->second, newV);
@@ -1522,13 +1510,40 @@ void LayoutRematerialization::hoistConvertDotOperand() {
15221510
void LayoutRematerialization::hoistConvertDotOperand(
15231511
ConvertLayoutOp convertOp) {
15241512
auto targetType = convertOp.getType();
1525-
// The pass is targeted to Nvidia mma/wgmma dot operands
1513+
// The pass is targeted to MMA dot operands
1514+
1515+
auto canBePipelined = [&](ConvertLayoutOp convertOp) {
1516+
// FIXME: Check that the parent is a for loop
1517+
auto parent = convertOp->getParentOp();
1518+
if (!parent)
1519+
return false;
1520+
1521+
// Find all the dot-like ops in the for loop that have a dot operand
1522+
// encoding on the lhs and check if any of them post-dominates the load +
1523+
// cvt
1524+
SmallVector<Operation *> dotLikeOps;
1525+
parent->walk([&](Operation *op) {
1526+
if (!isa<mlir::triton::DotOpInterface>(op))
1527+
return;
1528+
auto opType = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1529+
if (!opType)
1530+
return;
1531+
auto dotEnc = dyn_cast<DotOperandEncodingAttr>(opType.getEncoding());
1532+
if (!dotEnc)
1533+
return;
1534+
if (isa<MmaEncodingTrait>(dotEnc.getParent()))
1535+
dotLikeOps.push_back(op);
1536+
});
1537+
if (dotLikeOps.empty())
1538+
return false;
1539+
return llvm::any_of(dotLikeOps, [&](Operation *dot) {
1540+
return postDomInfo.postDominates(dot, convertOp);
1541+
});
1542+
};
1543+
15261544
// We move convert #dot_operand next to their loads. This is done
15271545
// so that it's then easy to pipeline these loads
1528-
// TODO: Perhaps we should do this whenever convertOp is within a loop
1529-
1530-
auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
1531-
if (!(dotEnc && isa<NvidiaMmaEncodingAttr>(dotEnc.getParent())))
1546+
if (!canBePipelined(convertOp))
15321547
return;
15331548

15341549
// We hoist over any operation that can be done without data movement between
@@ -1579,8 +1594,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
15791594
auto type = dyn_cast<RankedTensorType>(loadOp->getResult(0).getType());
15801595
if (!type)
15811596
continue;
1582-
auto newType = RankedTensorType::get(type.getShape(), type.getElementType(),
1583-
layout[loadOp->getResult(0)]);
1597+
auto newType = type.cloneWithEncoding(layout[loadOp->getResult(0)]);
15841598
auto newConvertOp = builder.create<ConvertLayoutOp>(
15851599
convertOp.getLoc(), newType, loadOp->getResult(0));
15861600
mapping.map(loadOp->getResult(0), newConvertOp.getResult());
@@ -1682,18 +1696,16 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
16821696
OpBuilder builder(extOrBroadcastOp);
16831697
auto tensorType =
16841698
cast<RankedTensorType>(extOrBroadcastOp->getOperand(0).getType());
1685-
auto newType = RankedTensorType::get(
1686-
tensorType.getShape(), tensorType.getElementType(), srcEncoding);
1699+
auto newType = tensorType.cloneWithEncoding(srcEncoding);
16871700
auto newConvertOp = builder.create<ConvertLayoutOp>(
16881701
convertOp.getLoc(), newType, extOrBroadcastOp->getOperand(0));
16891702
Operation *newExtOrBroadcast = builder.clone(*extOrBroadcastOp);
16901703
newExtOrBroadcast->setOperand(0, newConvertOp.getResult());
16911704
auto oldExtOrBroadcastType =
16921705
cast<RankedTensorType>(extOrBroadcastOp->getResult(0).getType());
1693-
Type newExtOrBroadcasrType = RankedTensorType::get(
1694-
oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(),
1695-
dstEncoding);
1696-
newExtOrBroadcast->getResult(0).setType(newExtOrBroadcasrType);
1706+
Type newExtOrBroadcastType =
1707+
oldExtOrBroadcastType.cloneWithEncoding(dstEncoding);
1708+
newExtOrBroadcast->getResult(0).setType(newExtOrBroadcastType);
16971709
IRMapping mapping;
16981710
mapping.map(extOrBroadcastOp->getResult(0), newExtOrBroadcast->getResult(0));
16991711
slice.remove(extOrBroadcastOp->getResult(0));
@@ -1798,8 +1810,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
17981810
IRMapping mapping;
17991811
auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) {
18001812
auto tensorType = cast<RankedTensorType>(v.getType());
1801-
auto newType = RankedTensorType::get(tensorType.getShape(),
1802-
tensorType.getElementType(), encoding);
1813+
auto newType = tensorType.cloneWithEncoding(encoding);
18031814
Value newCvt = b.create<ConvertLayoutOp>(convertOp.getLoc(), newType, v);
18041815

18051816
mapping.map(v, newCvt);

0 commit comments

Comments
 (0)