@@ -302,7 +302,8 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
302
302
303
303
// Only lower loadOp with dpas layout encoding.
304
304
auto tensorTy = cast<RankedTensorType>(op.getType ());
305
- return hasDpasEncoding (tensorTy) || hasDotDpasEncoding (tensorTy);
305
+ return hasDpasEncoding (tensorTy) || hasDotDpasEncoding (tensorTy) ||
306
+ hasSubgroup2DBlockEncoding (tensorTy);
306
307
}
307
308
308
309
template <
@@ -1416,12 +1417,31 @@ struct LoadOpConversion
1416
1417
auto tensorType = cast<RankedTensorType>(resultType);
1417
1418
1418
1419
const bool memoryRowMajor = isMemoryRowMajor (op);
1419
- DpasEncodingAttr::OpIdx opIdx = getOpIdx (tensorType);
1420
+
1421
+ auto getDpasTypeFromCVTOp = [&](Value opResult) -> RankedTensorType {
1422
+ for (OpOperand user : opResult.getUsers ()) {
1423
+ if (auto cvt = dyn_cast<ConvertLayoutOp>(user.getOwner ())) {
1424
+ return cast<RankedTensorType>(cvt.getResult ().getType ());
1425
+ // return getDpasLayout(cvt.getResult().getType());
1426
+ }
1427
+ }
1428
+ llvm_unreachable (" expected to find a cvt op with dpas layout" );
1429
+ };
1430
+
1431
+ auto dpasTensorType = hasSubgroup2DBlockEncoding (tensorType)
1432
+ ? getDpasTypeFromCVTOp (op.getResult ())
1433
+ : tensorType;
1434
+ llvm::errs () << " using dpas tensor type: " << dpasTensorType << " \n " ;
1435
+ DpasEncodingAttr dpasLayout = getDpasLayout (dpasTensorType);
1436
+
1437
+ DpasEncodingAttr::OpIdx opIdx = getOpIdx (dpasTensorType);
1420
1438
1421
1439
LLVM_DEBUG (llvm::dbgs () << " Tensor type for op " << int (opIdx) << " : "
1422
1440
<< tensorType << " \n " );
1423
1441
1424
1442
Attribute encoding = tensorType.getEncoding ();
1443
+ // TODO: this gives us the linear layour corresponding
1444
+ // to the subgroup 2d block encoding, not the dpas encoding...
1425
1445
std::optional<LinearLayout> llEncoding =
1426
1446
cast<DistributedEncodingTrait>(encoding).toLinearLayout (
1427
1447
tensorType.getShape ());
@@ -1440,14 +1460,21 @@ struct LoadOpConversion
1440
1460
Type eltTy = tensorType.getElementType ();
1441
1461
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
1442
1462
1443
- auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout (
1444
- cast<DistributedEncodingTrait>(encoding), tensorType.getShape (),
1445
- memoryRowMajor, elemSizeInBits / 8 , rewriter.getContext ());
1446
- unsigned tileHeight = tileParams[0 ];
1447
- const unsigned tileWidth = tileParams[1 ];
1448
- const unsigned vBlocks = tileParams[2 ];
1463
+ auto getTileParams = [&]() -> std::tuple<unsigned , unsigned , unsigned > {
1464
+ if (hasSubgroup2DBlockEncoding (tensorType)) {
1465
+ auto encoding =
1466
+ cast<Subgroup2DBlockEncodingAttr>(tensorType.getEncoding ());
1467
+ auto shape = encoding.getInstrShape ();
1468
+ return std::make_tuple (shape[0 ], shape[1 ], encoding.getNumBlocks ());
1469
+ } else {
1470
+ auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout (
1471
+ cast<DistributedEncodingTrait>(encoding), tensorType.getShape (),
1472
+ memoryRowMajor, elemSizeInBits / 8 , rewriter.getContext ());
1473
+ return std::make_tuple (tileParams[0 ], tileParams[1 ], tileParams[2 ]);
1474
+ }
1475
+ };
1476
+ auto [tileHeight, tileWidth, vBlocks] = getTileParams ();
1449
1477
1450
- DpasEncodingAttr dpasLayout = getDpasLayout (tensorType);
1451
1478
const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
1452
1479
unsigned numElems = getTotalElemsPerThread (resultType);
1453
1480
SmallVector<int64_t > numReps =
@@ -1617,6 +1644,7 @@ struct LoadOpConversion
1617
1644
// input operands to DPAS.
1618
1645
// TODO: add support for int4 and int2.
1619
1646
unsigned opsPerChannel = dpasLayout.getOpsPerChannel ();
1647
+ llvm::errs () << " opsPerChannel = " << opsPerChannel << " \n " ;
1620
1648
if ((opsPerChannel == 4 && elemSizeInBits == 8 ) ||
1621
1649
(opsPerChannel == 2 && elemSizeInBits == 16 ) ||
1622
1650
(opsPerChannel == 1 && elemSizeInBits == 32 )) {
@@ -1840,6 +1868,8 @@ struct LoadOpConversion
1840
1868
unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst *
1841
1869
numOperandsOuterDimPerLoad *
1842
1870
numOperandsInnerDimPerLoad;
1871
+ llvm::errs () << " num values per load = " << numValuesPerLoad << " \n " ;
1872
+ llvm::errs () << " loadResultElemType = " << loadResultElemType << " \n " ;
1843
1873
Type load2DGenXType =
1844
1874
LLVM::getVectorType (loadResultElemType, numValuesPerLoad);
1845
1875
@@ -2187,6 +2217,8 @@ struct LoadOpConversion
2187
2217
}
2188
2218
2189
2219
Type llvmResultStructTy = typeConverter->convertType (op.getType ());
2220
+ llvm::errs () << " op.getType() " << op.getType () << " \n " ;
2221
+ llvm::errs () << " llvmResultStructTy: " << llvmResultStructTy << " \n " ;
2190
2222
Value resultStruct = packLLElements (loc, typeConverter, unpackedLoadedVals,
2191
2223
rewriter, llvmResultStructTy);
2192
2224
rewriter.replaceOp (op, {resultStruct});
0 commit comments