@@ -201,6 +201,9 @@ bool isLayoutAnchor(Operation *op) {
   // AtomicCAS for further performance consideration.
   if (isa<DotOp, DotScaledOp, AtomicCASOp>(op))
     return true;
+  if (auto gatherOp = dyn_cast<GatherOp>(op))
+    return gatherOp.getEfficientLayout();
+
   if (isa<AtomicRMWOp>(op))
     if (auto tensorType =
             dyn_cast<RankedTensorType>(op->getResult(0).getType()))
@@ -529,7 +532,19 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
 
   Attribute operandEnc;
   if (op->getNumOperands() > 0) {
-    operandEnc = ttgi::inferSrcEncoding(op, encoding);
+    for (auto operand : op->getOperands()) {
+      auto ty =
+          dyn_cast<RankedTensorType>(getRewrittenValue(operand).getType());
+      if (!ty)
+        continue;
+      auto enc = ty.getEncoding();
+      if (inferDstEncoding(op, enc) == encoding) {
+        operandEnc = enc;
+        break;
+      }
+    }
+    if (!operandEnc)
+      operandEnc = ttgi::inferSrcEncoding(op, encoding);
     assert(operandEnc);
   }
 
@@ -1408,8 +1423,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
   // threads We do views and elementwise pure ops for now
   auto noDataMovement = [](Operation *op) {
     return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
-           isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, Fp4ToFpOp,
-               ConvertLayoutOp>(op);
+           isa<BroadcastOp, Fp4ToFpOp, ConvertLayoutOp>(op) || isView(op);
   };
   // Stop the slice as soon as we find an operation that cannot be done without
   // data movement between threads