Skip to content

Commit 3a39e40

Browse files
committed
Fix up handling of tensor pointers when lowering to gather load (with subgroup 2D block layout)
1 parent ebf5409 commit 3a39e40

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,15 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
343343
: getDotEncoding(tensorTy).value().getParent());
344344
}
345345

346+
/// Returns the result type of the first ConvertLayoutOp found among the users
/// of \p opResult.
///
/// Callers use this to obtain the tensor type the loaded value is converted
/// to. Note that the encoding of the convert's result is not checked here;
/// the first ConvertLayoutOp user wins.
///
/// \param opResult the value whose users are searched.
/// \return the ranked tensor type produced by the matching ConvertLayoutOp.
///
/// Hits llvm_unreachable if \p opResult has no ConvertLayoutOp user.
static RankedTensorType getDpasTypeFromCVTOp(Value opResult) {
  // Value::getUsers() iterates over the user operations themselves, so walk
  // them as Operation * rather than going through a temporary OpOperand.
  for (Operation *user : opResult.getUsers()) {
    if (auto cvt = dyn_cast<ConvertLayoutOp>(user))
      return cast<RankedTensorType>(cvt.getResult().getType());
  }
  llvm_unreachable("expected to find a cvt op with dpas layout");
}
354+
346355
// Returns the pitch (stride in bytes) of \p ptr.
347356
Value getPitch(ConversionPatternRewriter &rewriter, Value ptr,
348357
const std::map<SmallVector<unsigned>, Value> &ptrs,
@@ -1418,16 +1427,6 @@ struct LoadOpConversion
14181427

14191428
const bool memoryRowMajor = isMemoryRowMajor(op);
14201429

1421-
auto getDpasTypeFromCVTOp = [&](Value opResult) -> RankedTensorType {
1422-
for (OpOperand user : opResult.getUsers()) {
1423-
if (auto cvt = dyn_cast<ConvertLayoutOp>(user.getOwner())) {
1424-
return cast<RankedTensorType>(cvt.getResult().getType());
1425-
// return getDpasLayout(cvt.getResult().getType());
1426-
}
1427-
}
1428-
llvm_unreachable("expected to find a cvt op with dpas layout");
1429-
};
1430-
14311430
auto dpasTensorType = hasSubgroup2DBlockEncoding(tensorType)
14321431
? getDpasTypeFromCVTOp(op.getResult())
14331432
: tensorType;
@@ -2213,6 +2212,8 @@ struct LoadOpConversion
22132212
}
22142213

22152214
Type llvmResultStructTy = typeConverter->convertType(op.getType());
2215+
LLVM_DEBUG(llvm::dbgs() << "Packing load result in struct "
2216+
<< llvmResultStructTy << "\n");
22162217
Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
22172218
rewriter, llvmResultStructTy);
22182219
rewriter.replaceOp(op, {resultStruct});
@@ -2235,10 +2236,16 @@ struct LoadOpConversion
22352236
Value mask = op.getMask();
22362237
Value llMask = adaptor.getMask();
22372238

2239+
auto opType = op.getType();
2240+
// TODO: Override the OpType since conversion is still happening during Load
2241+
// lowering. Once we materialize ConvertLayoutOp this can be removed.
2242+
if (auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243+
hasSubgroup2DBlockEncoding(tensorTy))
2244+
opType = getDpasTypeFromCVTOp(op.getResult());
2245+
22382246
// Determine the vectorization size
2239-
Type valueElemTy =
2240-
typeConverter->convertType(getElementTypeOrSelf(op.getType()));
2241-
unsigned numElems = getTotalElemsPerThread(op.getType());
2247+
Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(opType));
2248+
unsigned numElems = getTotalElemsPerThread(opType);
22422249
unsigned vec = getVectorSize(ptr);
22432250
if (llMask)
22442251
vec = std::min<size_t>(vec, getMaskAlignment(mask));
@@ -2249,7 +2256,7 @@ struct LoadOpConversion
22492256

22502257
if (isTensorPointerType(ptr.getType())) {
22512258
// fallback to gather load.
2252-
auto tensorType = cast<RankedTensorType>(op.getType());
2259+
auto tensorType = cast<RankedTensorType>(opType);
22532260
std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
22542261
loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
22552262
op.getBoundaryCheck(), op.getPadding());
@@ -2396,7 +2403,7 @@ struct LoadOpConversion
23962403
}
23972404
} // end vec
23982405

2399-
Type llvmResultStructTy = typeConverter->convertType(op.getType());
2406+
Type llvmResultStructTy = typeConverter->convertType(opType);
24002407
Value resultStruct = packLLElements(loc, typeConverter, loadedVals,
24012408
rewriter, llvmResultStructTy);
24022409
rewriter.replaceOp(op, {resultStruct});

0 commit comments

Comments
 (0)