Skip to content

Commit f06f90b

Browse files
committed
fixup op type override
1 parent 2b354d2 commit f06f90b

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,8 +2239,8 @@ struct LoadOpConversion
22392239
auto opType = op.getType();
22402240
// TODO: Override the OpType since conversion is still happening during Load
22412241
// lowering. Once we materialize ConvertLayoutOp this can be removed.
2242-
if (auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243-
hasSubgroup2DBlockEncoding(tensorTy))
2242+
auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243+
if (tensorTy && hasSubgroup2DBlockEncoding(tensorTy))
22442244
opType = getDpasTypeFromCVTOp(op.getResult());
22452245

22462246
// Determine the vectorization size
@@ -2256,9 +2256,11 @@ struct LoadOpConversion
22562256

22572257
if (isTensorPointerType(ptr.getType())) {
22582258
// fallback to gather load.
2259-
auto tensorType = cast<RankedTensorType>(opType);
2259+
// make sure we use the modified opType from above, "seeing through" any
2260+
// post-subgroup 2d block encoding CVT.
2261+
auto blockPtrTensorType = cast<RankedTensorType>(opType);
22602262
std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
2261-
loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
2263+
loc, adaptor.getPtr(), blockPtrTensorType, valueElemTy, rewriter,
22622264
op.getBoundaryCheck(), op.getPadding());
22632265
} else {
22642266
Value other = op.getOther();

0 commit comments

Comments
 (0)