@@ -308,9 +308,66 @@ struct LoadStoreConversionBase {
   const triton::intel::TargetInfo &targetInfo;
 };
 
+struct BlockIOConversionBase : public LoadStoreConversionBase {
+  explicit BlockIOConversionBase(
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass)
+      : LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
+
+  // Determine whether the given LoadOp can be lowered using block IO
+  // instructions.
+  bool isLoadCandidate(triton::LoadOp op) const {
+    Attribute blockIOAttr =
+        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
+    if (!blockIOAttr)
+      return false;
+
+    // Only lower loadOp with dpas layout encoding.
+    auto tensorTy = cast<RankedTensorType>(op.getType());
+    return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy);
+  }
+
+  template <
+      typename OpTy,
+      std::enable_if_t<llvm::is_one_of<OpTy, triton::gpu::intel::PrefetchOp,
+                                       triton::LoadOp>::value,
+                       bool> = true>
+  bool isMemoryRowMajor(OpTy op) const {
+    Attribute blockIOAttr =
+        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
+    assert(blockIOAttr && "Expecting block IO attribute");
+
+    // TODO: To support more layouts on memory:
+    // https://github.com/intel/intel-xpu-backend-for-triton/issues/4057.
+    // Only support rank 2 dot layout, either row major or column major.
+    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
+    assert((memoryLayoutInfo == "row_major" ||
+            memoryLayoutInfo == "column_major") &&
+           "Only row_major or column_major is supported");
+    return memoryLayoutInfo == "row_major";
+  }
+
+  DpasEncodingAttr::OpIdx getOpIdx(RankedTensorType tensorTy) const {
+    if (hasDpasEncoding(tensorTy))
+      return DpasEncodingAttr::OpIdx::OperandC;
+
+    assert(hasDotDpasEncoding(tensorTy) && "Expecting dot layout");
+    DotOperandEncodingAttr dotLayout = getDotEncoding(tensorTy).value();
+    return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
+  }
+
+  DpasEncodingAttr getDpasLayout(RankedTensorType tensorTy) const {
+    Attribute encoding = tensorTy.getEncoding();
+    return cast<DpasEncodingAttr>(
+        hasDpasEncoding(tensorTy)
+            ? encoding
+            : getDotEncoding(tensorTy).value().getParent());
+  }
+};
+
 struct PrefetchOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>,
-      public LoadStoreConversionBase {
+      public BlockIOConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::intel::PrefetchOp>::ConvertTritonGPUOpToLLVMPattern;
 
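For context, a minimal sketch (not part of this patch) of how a load gets tagged so that isLoadCandidate() and isMemoryRowMajor() accept it. The helper name is hypothetical; the attribute key comes from the dialect, and the value must be "row_major" or "column_major" to satisfy the assertion above.

// Hypothetical tagging helper; an upstream analysis pass decides the actual
// layout string from how the pointer strides through memory.
static void tagAsBlockIO(triton::LoadOp loadOp, bool rowMajor) {
  loadOp->setAttr(
      TritonIntelGPUDialect::getBlockIOAttrName(),
      StringAttr::get(loadOp->getContext(),
                      rowMajor ? "row_major" : "column_major"));
}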
@@ -320,7 +377,7 @@ struct PrefetchOpConversion
                       PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>(
             converter, benefit),
-        LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
+        BlockIOConversionBase(targetInfo, axisAnalysisPass) {}
 
   LogicalResult
   matchAndRewrite(triton::gpu::intel::PrefetchOp op, OpAdaptor adaptor,
@@ -337,7 +394,6 @@ struct PrefetchOpConversion
   rewriteTensorPointerPrefetch(triton::gpu::intel::PrefetchOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
-
     Attribute blockIOAttr =
         op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
     if (!blockIOAttr) {
@@ -347,14 +403,6 @@ struct PrefetchOpConversion
       return success();
     }
 
-    // Only support rank 2 block pointer, either row major or column major.
-    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
-    assert((memoryLayoutInfo == "row_major" ||
-            memoryLayoutInfo == "column_major") &&
-           "Only row_major or column_major is supported");
-
-    const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
-
     auto mod = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -365,6 +413,7 @@ struct PrefetchOpConversion
     const ArrayRef<int64_t> shapeRef = tensorType.getShape();
     SmallVector<int64_t> tensorShape{shapeRef.begin(), shapeRef.end()};
 
+    const bool memoryRowMajor = isMemoryRowMajor(op);
     if (!memoryRowMajor) {
       // Swap the shape to make it row major and then get the tiling
       // size base on row major shape.
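For intuition, a standalone sketch of the normalization performed in the branch above. The helper below is illustrative only (the patch keeps the swap inline) and assumes a rank-2 shape, matching the block IO restriction.

// Treat a column-major [rows, cols] buffer as a row-major [cols, rows]
// buffer so the 2D prefetch tiling logic only reasons about row-major
// memory.
static SmallVector<int64_t> normalizeToRowMajor(ArrayRef<int64_t> shape,
                                                bool memoryRowMajor) {
  assert(shape.size() == 2 && "block IO supports only rank-2 shapes");
  SmallVector<int64_t> rowMajorShape{shape.begin(), shape.end()};
  if (!memoryRowMajor)
    std::swap(rowMajorShape[0], rowMajorShape[1]);
  return rowMajorShape;
}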
@@ -485,7 +534,7 @@ struct PrefetchOpConversion
 
 struct LoadOpToBlockIOConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
-      public LoadStoreConversionBase {
+      public BlockIOConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
       triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
 
@@ -496,14 +545,12 @@ struct LoadOpToBlockIOConversion
       const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
       PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
-        LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
+        BlockIOConversionBase(targetInfo, axisAnalysisPass) {}
 
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    Attribute blockIOAttr =
-        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
-    if (!blockIOAttr)
+    if (!isLoadCandidate(op))
       return failure();
 
     Location loc = op.getLoc();
@@ -515,31 +562,10 @@ struct LoadOpToBlockIOConversion
     Type resultType = op.getType();
     auto tensorType = cast<RankedTensorType>(resultType);
 
-    const bool hasDpasLayout = hasDpasEncoding(tensorType);
-    if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
-      return failure();
-
-    // Only lower loadOp with dpas layout encoding.
-    auto encoding = tensorType.getEncoding();
-
-    // TODO: To support more layouts on memory.
-    // Only support rank 2 dot layout, either row major or column major.
-    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
-    assert((memoryLayoutInfo == "row_major" ||
-            memoryLayoutInfo == "column_major") &&
-           "Only row_major or column_major is supported");
-    const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
-
-    auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
-      if (hasDpasLayout)
-        return DpasEncodingAttr::OpIdx::OperandC;
-
-      assert(hasDotDpasEncoding(tensorType) && "Expecting dot layout");
-      DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
-      return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
-    };
-    DpasEncodingAttr::OpIdx opIdx = getOpIdx();
+    const bool memoryRowMajor = isMemoryRowMajor(op);
+    DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
 
+    Attribute encoding = tensorType.getEncoding();
     std::optional<LinearLayout> llEncoding =
         cast<DistributedEncodingTrait>(encoding).toLinearLayout(
             tensorType.getShape());
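A note on the static_cast inside getOpIdx(): it relies on DotOperandEncodingAttr::getOpIdx() (0 for the A operand of tt.dot, 1 for B) lining up with the DpasEncodingAttr::OpIdx enumerators. A guard one could add, assuming the enum declares OperandA and OperandB with those underlying values:

static_assert(static_cast<unsigned>(DpasEncodingAttr::OpIdx::OperandA) == 0 &&
                  static_cast<unsigned>(DpasEncodingAttr::OpIdx::OperandB) == 1,
              "getOpIdx() casts a dot-operand index to DpasEncodingAttr::OpIdx");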
@@ -556,12 +582,7 @@ struct LoadOpToBlockIOConversion
 
     Type eltTy = tensorType.getElementType();
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
-
-    auto dpasLayout = hasDpasLayout
-                          ? cast<DpasEncodingAttr>(encoding)
-                          : cast<DpasEncodingAttr>(
-                                getDotEncoding(tensorType).value().getParent());
-
+    DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
     const ArrayRef<int64_t> tensorShape = tensorType.getShape();
     unsigned numElems = getTotalElemsPerThread(resultType);
     SmallVector<int64_t> numReps =
@@ -983,7 +1004,7 @@ struct LoadOpToBlockIOConversion
 
 struct LoadOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
-      public LoadStoreConversionBase {
+      public BlockIOConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
       triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
 
@@ -995,7 +1016,7 @@ struct LoadOpConversion
       PatternBenefit benefit, bool oneMatrixPerLoadForBT,
       bool useTileLoadLinearLayout)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
-        LoadStoreConversionBase(targetInfo, axisAnalysisPass),
+        BlockIOConversionBase(targetInfo, axisAnalysisPass),
         oneMatrixPerLoadForBT(oneMatrixPerLoadForBT),
         useTileLoadLinearLayout(useTileLoadLinearLayout) {}
 
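Note that LoadOpToBlockIOConversion and LoadOpConversion both match triton::LoadOp, so the PatternBenefit passed here decides which pattern the conversion driver tries first. A rough sketch of the registration call, with variable names and argument order inferred from the constructors above rather than taken from this patch:

patterns.add<LoadOpToBlockIOConversion>(typeConverter, targetInfo,
                                        axisInfoAnalysis, blockIOBenefit);
patterns.add<LoadOpConversion>(typeConverter, targetInfo, axisInfoAnalysis,
                               baseBenefit, oneMatrixPerLoadForBT,
                               useTileLoadLinearLayout);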
@@ -1004,7 +1025,10 @@ struct LoadOpConversion
                   ConversionPatternRewriter &rewriter) const {
     Value ptr = op.getPtr();
     assert(isTensorPointerType(ptr.getType()) &&
-           "Expecting tensor of pointer type");
+           "Expecting tensor pointer type");
+
+    if (!isLoadCandidate(op))
+      return failure();
 
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -1013,37 +1037,13 @@ struct LoadOpConversion
     Type resultType = op.getType();
     auto tensorType = cast<RankedTensorType>(resultType);
 
-    // Only lower loadOp with dpas layout encoding.
-    auto encoding = tensorType.getEncoding();
-    const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
-    if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
-      return failure();
-
-    Attribute blockIOAttr =
-        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
-    if (!blockIOAttr)
-      return failure();
-
-    // Only support rank 2 dot layout, either row major or column major.
-    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
-    assert((memoryLayoutInfo == "row_major" ||
-            memoryLayoutInfo == "column_major") &&
-           "Only row_major or column_major is supported");
-    const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
-
-    auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
-      if (hasDpasLayout) {
-        return DpasEncodingAttr::OpIdx::OperandC;
-      } else {
-        auto dotLayout = getDotEncoding(tensorType).value();
-        return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
-      }
-    };
-    auto opIdx = getOpIdx();
+    const bool memoryRowMajor = isMemoryRowMajor(op);
+    DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
 
     LLVM_DEBUG(llvm::dbgs() << "Tensor type for op " << int(opIdx) << ": "
                             << tensorType << "\n");
 
+    Attribute encoding = tensorType.getEncoding();
     std::optional<LinearLayout> llEncoding =
         cast<DistributedEncodingTrait>(encoding).toLinearLayout(
             tensorType.getShape());
@@ -1061,12 +1061,7 @@ struct LoadOpConversion
 
     Type eltTy = tensorType.getElementType();
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
-
-    auto dpasLayout = hasDpasLayout
-                          ? cast<DpasEncodingAttr>(encoding)
-                          : cast<DpasEncodingAttr>(
-                                getDotEncoding(tensorType).value().getParent());
-
+    DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
     const ArrayRef<int64_t> tensorShape = tensorType.getShape();
     unsigned numElems = getTotalElemsPerThread(resultType);
     SmallVector<int64_t> numReps =
@@ -1084,7 +1079,7 @@ struct LoadOpConversion
     SmallVector<Value> multiDimWarpId =
         delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);
 
-    if (hasDpasLayout) {
+    if (opIdx == DpasEncodingAttr::OpIdx::OperandC) {
       // A block load with the DPAS layout but without the DotDpasLayout is
       // expected to follow the ordering of the DPAS output. For a 2D block
       // load, the rows are distributed across work items/SIMD lanes and the