@@ -2239,8 +2239,8 @@ struct LoadOpConversion
2239
2239
auto opType = op.getType ();
2240
2240
// TODO: Override the OpType since conversion is still happening during Load
2241
2241
// lowering. Once we materialize ConvertLayoutOp this can be removed.
2242
- if ( auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243
- hasSubgroup2DBlockEncoding (tensorTy))
2242
+ auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243
+ if (tensorTy && hasSubgroup2DBlockEncoding (tensorTy))
2244
2244
opType = getDpasTypeFromCVTOp (op.getResult ());
2245
2245
2246
2246
// Determine the vectorization size
@@ -2256,9 +2256,11 @@ struct LoadOpConversion
2256
2256
2257
2257
if (isTensorPointerType (ptr.getType ())) {
2258
2258
// fallback to gather load.
2259
- auto tensorType = cast<RankedTensorType>(opType);
2259
+ // make sure we use the modified opType from above, "seeing through" any
2260
+ // post-subgroup 2d block encoding CVT.
2261
+ auto blockPtrTensorType = cast<RankedTensorType>(opType);
2260
2262
std::tie (ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr (
2261
- loc, adaptor.getPtr (), tensorType , valueElemTy, rewriter,
2263
+ loc, adaptor.getPtr (), blockPtrTensorType , valueElemTy, rewriter,
2262
2264
op.getBoundaryCheck (), op.getPadding ());
2263
2265
} else {
2264
2266
Value other = op.getOther ();
0 commit comments