@@ -1386,7 +1386,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
1386
1386
if (result.failed ())
1387
1387
return ;
1388
1388
1389
- Operation *extOrBroadcatOp = nullptr ;
1389
+ Operation *extOrBroadcastOp = nullptr ;
1390
1390
unsigned sliceSize = slice.size ();
1391
1391
for (unsigned i = 0 ; i < sliceSize; i++) {
1392
1392
Value v = slice[i];
@@ -1410,37 +1410,37 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
1410
1410
}
1411
1411
// Only apply it if there is a single ext op otherwise we would have to
1412
1412
// duplicate the convert.
1413
- if (extOrBroadcatOp != nullptr )
1413
+ if (extOrBroadcastOp != nullptr )
1414
1414
return ;
1415
- extOrBroadcatOp = op;
1415
+ extOrBroadcastOp = op;
1416
1416
}
1417
1417
}
1418
1418
1419
- if (extOrBroadcatOp == nullptr )
1419
+ if (extOrBroadcastOp == nullptr )
1420
1420
return ;
1421
- Attribute dstEncoding = layout[extOrBroadcatOp ->getResult (0 )];
1422
- Attribute srcEncoding = ttgi::inferSrcEncoding (extOrBroadcatOp , dstEncoding);
1421
+ Attribute dstEncoding = layout[extOrBroadcastOp ->getResult (0 )];
1422
+ Attribute srcEncoding = ttgi::inferSrcEncoding (extOrBroadcastOp , dstEncoding);
1423
1423
if (!srcEncoding)
1424
1424
return ;
1425
1425
// Move the convert before the ext op and rewrite the slice.
1426
- OpBuilder builder (extOrBroadcatOp );
1426
+ OpBuilder builder (extOrBroadcastOp );
1427
1427
auto tensorType =
1428
- cast<RankedTensorType>(extOrBroadcatOp ->getOperand (0 ).getType ());
1428
+ cast<RankedTensorType>(extOrBroadcastOp ->getOperand (0 ).getType ());
1429
1429
auto newType = RankedTensorType::get (
1430
1430
tensorType.getShape (), tensorType.getElementType (), srcEncoding);
1431
1431
auto newConvertOp = builder.create <ConvertLayoutOp>(
1432
- convertOp.getLoc (), newType, extOrBroadcatOp ->getOperand (0 ));
1433
- Operation *newExtOrBroadcast = builder.clone (*extOrBroadcatOp );
1432
+ convertOp.getLoc (), newType, extOrBroadcastOp ->getOperand (0 ));
1433
+ Operation *newExtOrBroadcast = builder.clone (*extOrBroadcastOp );
1434
1434
newExtOrBroadcast->setOperand (0 , newConvertOp.getResult ());
1435
1435
auto oldExtOrBroadcastType =
1436
- cast<RankedTensorType>(extOrBroadcatOp ->getResult (0 ).getType ());
1436
+ cast<RankedTensorType>(extOrBroadcastOp ->getResult (0 ).getType ());
1437
1437
Type newExtOrBroadcasrType = RankedTensorType::get (
1438
1438
oldExtOrBroadcastType.getShape (), oldExtOrBroadcastType.getElementType (),
1439
1439
dstEncoding);
1440
1440
newExtOrBroadcast->getResult (0 ).setType (newExtOrBroadcasrType);
1441
1441
IRMapping mapping;
1442
- mapping.map (extOrBroadcatOp ->getResult (0 ), newExtOrBroadcast->getResult (0 ));
1443
- slice.remove (extOrBroadcatOp ->getResult (0 ));
1442
+ mapping.map (extOrBroadcastOp ->getResult (0 ), newExtOrBroadcast->getResult (0 ));
1443
+ slice.remove (extOrBroadcastOp ->getResult (0 ));
1444
1444
// 3. Rewrite the slice.
1445
1445
rewriteSlice (slice, layout, convertOp, mapping);
1446
1446
}
0 commit comments