@@ -343,6 +343,15 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
343
343
: getDotEncoding (tensorTy).value ().getParent ());
344
344
}
345
345
346
+ static RankedTensorType getDpasTypeFromCVTOp (Value opResult) {
347
+ for (OpOperand user : opResult.getUsers ()) {
348
+ if (auto cvt = dyn_cast<ConvertLayoutOp>(user.getOwner ())) {
349
+ return cast<RankedTensorType>(cvt.getResult ().getType ());
350
+ }
351
+ }
352
+ llvm_unreachable (" expected to find a cvt op with dpas layout" );
353
+ }
354
+
346
355
// Returns the pitch (stride in bytes) of \p ptr.
347
356
Value getPitch (ConversionPatternRewriter &rewriter, Value ptr,
348
357
const std::map<SmallVector<unsigned >, Value> &ptrs,
@@ -1418,16 +1427,6 @@ struct LoadOpConversion
1418
1427
1419
1428
const bool memoryRowMajor = isMemoryRowMajor (op);
1420
1429
1421
- auto getDpasTypeFromCVTOp = [&](Value opResult) -> RankedTensorType {
1422
- for (OpOperand user : opResult.getUsers ()) {
1423
- if (auto cvt = dyn_cast<ConvertLayoutOp>(user.getOwner ())) {
1424
- return cast<RankedTensorType>(cvt.getResult ().getType ());
1425
- // return getDpasLayout(cvt.getResult().getType());
1426
- }
1427
- }
1428
- llvm_unreachable (" expected to find a cvt op with dpas layout" );
1429
- };
1430
-
1431
1430
auto dpasTensorType = hasSubgroup2DBlockEncoding (tensorType)
1432
1431
? getDpasTypeFromCVTOp (op.getResult ())
1433
1432
: tensorType;
@@ -2213,6 +2212,8 @@ struct LoadOpConversion
2213
2212
}
2214
2213
2215
2214
Type llvmResultStructTy = typeConverter->convertType (op.getType ());
2215
+ LLVM_DEBUG (llvm::dbgs () << " Packing load result in struct "
2216
+ << llvmResultStructTy << " \n " );
2216
2217
Value resultStruct = packLLElements (loc, typeConverter, unpackedLoadedVals,
2217
2218
rewriter, llvmResultStructTy);
2218
2219
rewriter.replaceOp (op, {resultStruct});
@@ -2235,10 +2236,16 @@ struct LoadOpConversion
2235
2236
Value mask = op.getMask ();
2236
2237
Value llMask = adaptor.getMask ();
2237
2238
2239
+ auto opType = op.getType ();
2240
+ // TODO: Override the OpType since conversion is still happening during Load
2241
+ // lowering. Once we materialize ConvertLayoutOp this can be removed.
2242
+ if (auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243
+ hasSubgroup2DBlockEncoding (tensorTy))
2244
+ opType = getDpasTypeFromCVTOp (op.getResult ());
2245
+
2238
2246
// Determine the vectorization size
2239
- Type valueElemTy =
2240
- typeConverter->convertType (getElementTypeOrSelf (op.getType ()));
2241
- unsigned numElems = getTotalElemsPerThread (op.getType ());
2247
+ Type valueElemTy = typeConverter->convertType (getElementTypeOrSelf (opType));
2248
+ unsigned numElems = getTotalElemsPerThread (opType);
2242
2249
unsigned vec = getVectorSize (ptr);
2243
2250
if (llMask)
2244
2251
vec = std::min<size_t >(vec, getMaskAlignment (mask));
@@ -2249,7 +2256,7 @@ struct LoadOpConversion
2249
2256
2250
2257
if (isTensorPointerType (ptr.getType ())) {
2251
2258
// fallback to gather load.
2252
- auto tensorType = cast<RankedTensorType>(op. getType () );
2259
+ auto tensorType = cast<RankedTensorType>(opType );
2253
2260
std::tie (ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr (
2254
2261
loc, adaptor.getPtr (), tensorType, valueElemTy, rewriter,
2255
2262
op.getBoundaryCheck (), op.getPadding ());
@@ -2396,7 +2403,7 @@ struct LoadOpConversion
2396
2403
}
2397
2404
} // end vec
2398
2405
2399
- Type llvmResultStructTy = typeConverter->convertType (op. getType () );
2406
+ Type llvmResultStructTy = typeConverter->convertType (opType );
2400
2407
Value resultStruct = packLLElements (loc, typeConverter, loadedVals,
2401
2408
rewriter, llvmResultStructTy);
2402
2409
rewriter.replaceOp (op, {resultStruct});
0 commit comments