Skip to content

Commit 4553e88

Browse files
authored
[RemoveLayoutConversion]: Sync implementation with upstream (#4693)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent e423617 commit 4553e88

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ bool isLayoutAnchor(Operation *op) {
201201
// AtomicCAS for further performance consideration.
202202
if (isa<DotOp, DotScaledOp, AtomicCASOp>(op))
203203
return true;
204+
if (auto gatherOp = dyn_cast<GatherOp>(op))
205+
return gatherOp.getEfficientLayout();
206+
204207
if (isa<AtomicRMWOp>(op))
205208
if (auto tensorType =
206209
dyn_cast<RankedTensorType>(op->getResult(0).getType()))
@@ -529,7 +532,19 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
529532

530533
Attribute operandEnc;
531534
if (op->getNumOperands() > 0) {
532-
operandEnc = ttgi::inferSrcEncoding(op, encoding);
535+
for (auto operand : op->getOperands()) {
536+
auto ty =
537+
dyn_cast<RankedTensorType>(getRewrittenValue(operand).getType());
538+
if (!ty)
539+
continue;
540+
auto enc = ty.getEncoding();
541+
if (inferDstEncoding(op, enc) == encoding) {
542+
operandEnc = enc;
543+
break;
544+
}
545+
}
546+
if (!operandEnc)
547+
operandEnc = ttgi::inferSrcEncoding(op, encoding);
533548
assert(operandEnc);
534549
}
535550

@@ -1408,8 +1423,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
14081423
// threads We do views and elementwise pure ops for now
14091424
auto noDataMovement = [](Operation *op) {
14101425
return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
1411-
isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, Fp4ToFpOp,
1412-
ConvertLayoutOp>(op);
1426+
isa<BroadcastOp, Fp4ToFpOp, ConvertLayoutOp>(op) || isView(op);
14131427
};
14141428
// Stop the slice as soon as we find an operation that cannot be done without
14151429
// data movement between threads

0 commit comments

Comments
 (0)