Skip to content

Commit 81a0018

Browse files
authored
[intel] Update RemoveLayoutConversions with changes from a1b44c6 (#4351)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 4e94523 commit 81a0018

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,7 +1386,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
13861386
if (result.failed())
13871387
return;
13881388

1389-
Operation *extOrBroadcatOp = nullptr;
1389+
Operation *extOrBroadcastOp = nullptr;
13901390
unsigned sliceSize = slice.size();
13911391
for (unsigned i = 0; i < sliceSize; i++) {
13921392
Value v = slice[i];
@@ -1410,37 +1410,37 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14101410
}
14111411
// Only apply it if there is a single ext op otherwise we would have to
14121412
// duplicate the convert.
1413-
if (extOrBroadcatOp != nullptr)
1413+
if (extOrBroadcastOp != nullptr)
14141414
return;
1415-
extOrBroadcatOp = op;
1415+
extOrBroadcastOp = op;
14161416
}
14171417
}
14181418

1419-
if (extOrBroadcatOp == nullptr)
1419+
if (extOrBroadcastOp == nullptr)
14201420
return;
1421-
Attribute dstEncoding = layout[extOrBroadcatOp->getResult(0)];
1422-
Attribute srcEncoding = ttgi::inferSrcEncoding(extOrBroadcatOp, dstEncoding);
1421+
Attribute dstEncoding = layout[extOrBroadcastOp->getResult(0)];
1422+
Attribute srcEncoding = ttgi::inferSrcEncoding(extOrBroadcastOp, dstEncoding);
14231423
if (!srcEncoding)
14241424
return;
14251425
// Move the convert before the ext op and rewrite the slice.
1426-
OpBuilder builder(extOrBroadcatOp);
1426+
OpBuilder builder(extOrBroadcastOp);
14271427
auto tensorType =
1428-
cast<RankedTensorType>(extOrBroadcatOp->getOperand(0).getType());
1428+
cast<RankedTensorType>(extOrBroadcastOp->getOperand(0).getType());
14291429
auto newType = RankedTensorType::get(
14301430
tensorType.getShape(), tensorType.getElementType(), srcEncoding);
14311431
auto newConvertOp = builder.create<ConvertLayoutOp>(
1432-
convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0));
1433-
Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp);
1432+
convertOp.getLoc(), newType, extOrBroadcastOp->getOperand(0));
1433+
Operation *newExtOrBroadcast = builder.clone(*extOrBroadcastOp);
14341434
newExtOrBroadcast->setOperand(0, newConvertOp.getResult());
14351435
auto oldExtOrBroadcastType =
1436-
cast<RankedTensorType>(extOrBroadcatOp->getResult(0).getType());
1436+
cast<RankedTensorType>(extOrBroadcastOp->getResult(0).getType());
14371437
Type newExtOrBroadcasrType = RankedTensorType::get(
14381438
oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(),
14391439
dstEncoding);
14401440
newExtOrBroadcast->getResult(0).setType(newExtOrBroadcasrType);
14411441
IRMapping mapping;
1442-
mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0));
1443-
slice.remove(extOrBroadcatOp->getResult(0));
1442+
mapping.map(extOrBroadcastOp->getResult(0), newExtOrBroadcast->getResult(0));
1443+
slice.remove(extOrBroadcastOp->getResult(0));
14441444
// 3. Rewrite the slice.
14451445
rewriteSlice(slice, layout, convertOp, mapping);
14461446
}

0 commit comments

Comments
 (0)