Commit ef82ff7

[NFI][LoadStoreOpToLLVM] Create BlockIOConversionBase for common BlockIO functions (#4053)
This is an initial attempt to reduce the obvious code duplication; there is likely more duplication that can still be reduced.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 8f04531 commit ef82ff7


third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 78 additions & 83 deletions

@@ -308,9 +308,66 @@ struct LoadStoreConversionBase {
   const triton::intel::TargetInfo &targetInfo;
 };
 
+struct BlockIOConversionBase : public LoadStoreConversionBase {
+  explicit BlockIOConversionBase(
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass)
+      : LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
+
+  // Determine whether the given LoadOp can be lowered to using block IO
+  // instructions.
+  bool isLoadCandidate(triton::LoadOp op) const {
+    Attribute blockIOAttr =
+        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
+    if (!blockIOAttr)
+      return false;
+
+    // Only lower loadOp with dpas layout encoding.
+    auto tensorTy = cast<RankedTensorType>(op.getType());
+    return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy);
+  }
+
+  template <
+      typename OpTy,
+      std::enable_if_t<llvm::is_one_of<OpTy, triton::gpu::intel::PrefetchOp,
+                                       triton::LoadOp>::value,
+                       bool> = true>
+  bool isMemoryRowMajor(OpTy op) const {
+    Attribute blockIOAttr =
+        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
+    assert(blockIOAttr && "Expecting block IO attribute");
+
+    // TODO: To support more layouts on memory:
+    // https://github.com/intel/intel-xpu-backend-for-triton/issues/4057.
+    // Only support rank 2 dot layout, either row major or column major.
+    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
+    assert((memoryLayoutInfo == "row_major" ||
+            memoryLayoutInfo == "column_major") &&
+           "Only row_major or column_major is supported");
+    return memoryLayoutInfo == "row_major";
+  }
+
+  DpasEncodingAttr::OpIdx getOpIdx(RankedTensorType tensorTy) const {
+    if (hasDpasEncoding(tensorTy))
+      return DpasEncodingAttr::OpIdx::OperandC;
+
+    assert(hasDotDpasEncoding(tensorTy) && "Expecting dot layout");
+    DotOperandEncodingAttr dotLayout = getDotEncoding(tensorTy).value();
+    return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
+  }
+
+  DpasEncodingAttr getDpasLayout(RankedTensorType tensorTy) const {
+    Attribute encoding = tensorTy.getEncoding();
+    return cast<DpasEncodingAttr>(
+        hasDpasEncoding(tensorTy)
+            ? encoding
+            : getDotEncoding(tensorTy).value().getParent());
+  }
+};
+
 struct PrefetchOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>,
-      public LoadStoreConversionBase {
+      public BlockIOConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::intel::PrefetchOp>::ConvertTritonGPUOpToLLVMPattern;
 
@@ -320,7 +377,7 @@ struct PrefetchOpConversion
                        PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>(
             converter, benefit),
-        LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
+        BlockIOConversionBase(targetInfo, axisAnalysisPass) {}
 
   LogicalResult
   matchAndRewrite(triton::gpu::intel::PrefetchOp op, OpAdaptor adaptor,
@@ -337,7 +394,6 @@ struct PrefetchOpConversion
   rewriteTensorPointerPrefetch(triton::gpu::intel::PrefetchOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
-
     Attribute blockIOAttr =
         op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
     if (!blockIOAttr) {
@@ -347,14 +403,6 @@ struct PrefetchOpConversion
       return success();
     }
 
-    // Only support rank 2 block pointer, either row major or column major.
-    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
-    assert((memoryLayoutInfo == "row_major" ||
-            memoryLayoutInfo == "column_major") &&
-           "Only row_major or column_major is supported");
-
-    const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
-
     auto mod = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -365,6 +413,7 @@ struct PrefetchOpConversion
     const ArrayRef<int64_t> shapeRef = tensorType.getShape();
     SmallVector<int64_t> tensorShape{shapeRef.begin(), shapeRef.end()};
 
+    const bool memoryRowMajor = isMemoryRowMajor(op);
     if (!memoryRowMajor) {
       // Swap the shape to make it row major and then get the tiling
       // size base on row major shape.
@@ -485,7 +534,7 @@ struct PrefetchOpConversion
 
 struct LoadOpToBlockIOConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
-      public LoadStoreConversionBase {
+      public BlockIOConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
       triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
 
@@ -496,14 +545,12 @@ struct LoadOpToBlockIOConversion
       const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
       PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
-        LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
+        BlockIOConversionBase(targetInfo, axisAnalysisPass) {}
 
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    Attribute blockIOAttr =
-        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
-    if (!blockIOAttr)
+    if (!isLoadCandidate(op))
       return failure();
 
     Location loc = op.getLoc();
@@ -515,31 +562,10 @@ struct LoadOpToBlockIOConversion
     Type resultType = op.getType();
     auto tensorType = cast<RankedTensorType>(resultType);
 
-    const bool hasDpasLayout = hasDpasEncoding(tensorType);
-    if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
-      return failure();
-
-    // Only lower loadOp with dpas layout encoding.
-    auto encoding = tensorType.getEncoding();
-
-    // TODO: To support more layouts on memory.
-    // Only support rank 2 dot layout, either row major or column major.
-    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
-    assert((memoryLayoutInfo == "row_major" ||
-            memoryLayoutInfo == "column_major") &&
-           "Only row_major or column_major is supported");
-    const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
-
-    auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
-      if (hasDpasLayout)
-        return DpasEncodingAttr::OpIdx::OperandC;
-
-      assert(hasDotDpasEncoding(tensorType) && "Expecting dot layout");
-      DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
-      return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
-    };
-    DpasEncodingAttr::OpIdx opIdx = getOpIdx();
+    const bool memoryRowMajor = isMemoryRowMajor(op);
+    DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
 
+    Attribute encoding = tensorType.getEncoding();
     std::optional<LinearLayout> llEncoding =
         cast<DistributedEncodingTrait>(encoding).toLinearLayout(
             tensorType.getShape());
@@ -556,12 +582,7 @@ struct LoadOpToBlockIOConversion
 
     Type eltTy = tensorType.getElementType();
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
-
-    auto dpasLayout = hasDpasLayout
-                          ? cast<DpasEncodingAttr>(encoding)
-                          : cast<DpasEncodingAttr>(
-                                getDotEncoding(tensorType).value().getParent());
-
+    DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
     const ArrayRef<int64_t> tensorShape = tensorType.getShape();
     unsigned numElems = getTotalElemsPerThread(resultType);
     SmallVector<int64_t> numReps =
@@ -983,7 +1004,7 @@ struct LoadOpToBlockIOConversion
 
 struct LoadOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
-      public LoadStoreConversionBase {
+      public BlockIOConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
       triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
 
@@ -995,7 +1016,7 @@ struct LoadOpConversion
                    PatternBenefit benefit, bool oneMatrixPerLoadForBT,
                    bool useTileLoadLinearLayout)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
-        LoadStoreConversionBase(targetInfo, axisAnalysisPass),
+        BlockIOConversionBase(targetInfo, axisAnalysisPass),
         oneMatrixPerLoadForBT(oneMatrixPerLoadForBT),
         useTileLoadLinearLayout(useTileLoadLinearLayout) {}
 
@@ -1004,7 +1025,10 @@ struct LoadOpConversion
                           ConversionPatternRewriter &rewriter) const {
     Value ptr = op.getPtr();
     assert(isTensorPointerType(ptr.getType()) &&
-           "Expecting tensor of pointer type");
+           "Expecting tensor pointer type");
+
+    if (!isLoadCandidate(op))
+      return failure();
 
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -1013,37 +1037,13 @@ struct LoadOpConversion
     Type resultType = op.getType();
    auto tensorType = cast<RankedTensorType>(resultType);
 
-    // Only lower loadOp with dpas layout encoding.
-    auto encoding = tensorType.getEncoding();
-    const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
-    if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
-      return failure();
-
-    Attribute blockIOAttr =
-        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
-    if (!blockIOAttr)
-      return failure();
-
-    // Only support rank 2 dot layout, either row major or column major.
-    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
-    assert((memoryLayoutInfo == "row_major" ||
-            memoryLayoutInfo == "column_major") &&
-           "Only row_major or column_major is supported");
-    const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
-
-    auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
-      if (hasDpasLayout) {
-        return DpasEncodingAttr::OpIdx::OperandC;
-      } else {
-        auto dotLayout = getDotEncoding(tensorType).value();
-        return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
-      }
-    };
-    auto opIdx = getOpIdx();
+    const bool memoryRowMajor = isMemoryRowMajor(op);
+    DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
 
     LLVM_DEBUG(llvm::dbgs() << "Tensor type for op " << int(opIdx) << ": "
                             << tensorType << "\n");
 
+    Attribute encoding = tensorType.getEncoding();
     std::optional<LinearLayout> llEncoding =
         cast<DistributedEncodingTrait>(encoding).toLinearLayout(
             tensorType.getShape());
@@ -1061,12 +1061,7 @@ struct LoadOpConversion
 
     Type eltTy = tensorType.getElementType();
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
-
-    auto dpasLayout = hasDpasLayout
-                          ? cast<DpasEncodingAttr>(encoding)
-                          : cast<DpasEncodingAttr>(
-                                getDotEncoding(tensorType).value().getParent());
-
+    DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
     const ArrayRef<int64_t> tensorShape = tensorType.getShape();
     unsigned numElems = getTotalElemsPerThread(resultType);
     SmallVector<int64_t> numReps =
@@ -1084,7 +1079,7 @@ struct LoadOpConversion
     SmallVector<Value> multiDimWarpId =
         delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);
 
-    if (hasDpasLayout) {
+    if (opIdx == DpasEncodingAttr::OpIdx::OperandC) {
      // A block load with the DPAS layout but without the DotDpasLayout is
      // expected to follow the ordering of the DPAS output. For a 2D block
      // load, the rows are distributed across work items/SIMD lanes and the
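
To make the effect of the refactoring easier to see, here is a condensed sketch of how a derived pattern consumes the shared BlockIOConversionBase helpers after this change. It is assembled from the hunks above with the rest of the lowering elided, so it is an illustration rather than a verbatim excerpt of LoadStoreOpToLLVM.cpp.

    // Condensed illustration only; all names come from the diff above, and the
    // elided lowering logic in the real matchAndRewrite bodies is unchanged.
    struct LoadOpToBlockIOConversion
        : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
          public BlockIOConversionBase {
      LogicalResult
      matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const final {
        // Shared candidate check: block IO attribute present and a DPAS or
        // dot-DPAS encoded result tensor.
        if (!isLoadCandidate(op))
          return failure();

        auto tensorType = cast<RankedTensorType>(op.getType());
        // Shared queries that replace the duplicated attribute parsing and the
        // per-pattern getOpIdx lambdas removed by this commit.
        const bool memoryRowMajor = isMemoryRowMajor(op);
        DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
        DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);

        // ... block IO lowering continues as before ...
        return success();
      }
    };

PrefetchOpConversion and LoadOpConversion inherit from the same base, so they pick up isMemoryRowMajor (and, for LoadOpConversion, the other helpers) without repeating the attribute and layout checks.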
