@@ -977,16 +977,23 @@ struct LoadOpToBlockIOConversion
// Constructor (shown as a diff). The change drops the `oneMatrixPerLoadForBT`
// parameter and its member initializer (lines marked `-`); the pattern now
// takes only `useTileLoadLinearLayout` in addition to the converter, target
// info, axis analysis and benefit.
// `converter` and `benefit` go to the ConvertTritonGPUOpToLLVMPattern
// initializer; `targetInfo` and `axisAnalysisPass` go to the
// BlockIOConversionBase initializer.
977977 LoadOpToBlockIOConversion (
978978 LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
979979 const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
980- PatternBenefit benefit, bool oneMatrixPerLoadForBT,
981- bool useTileLoadLinearLayout)
980+ PatternBenefit benefit, bool useTileLoadLinearLayout)
982981 : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
983982 BlockIOConversionBase (targetInfo, axisAnalysisPass),
984- oneMatrixPerLoadForBT (oneMatrixPerLoadForBT),
985983 useTileLoadLinearLayout (useTileLoadLinearLayout) {}
986984
987985 LogicalResult
988986 rewriteTensorPointerLoad (triton::LoadOp op, OpAdaptor adaptor,
989987 ConversionPatternRewriter &rewriter) const {
988+ // FIXME: Remove once IGC can split large 2D block loads.
989+ std::optional<bool > oneMatrixPerLoadForBT =
990+ mlir::triton::tools::isEnvValueBool (mlir::triton::tools::getStrEnv (
991+ " TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT" ));
992+ if (!oneMatrixPerLoadForBT.has_value ())
993+ oneMatrixPerLoadForBT =
994+ op->hasAttr (triton::gpu::intel::TritonIntelGPUDialect::
995+ getOneMatrixPerLoadAttrName ());
996+
990997 Value ptr = op.getPtr ();
991998 assert (isTensorPointerType (ptr.getType ()) &&
992999 " Expecting tensor pointer type" );
@@ -1342,7 +1349,7 @@ struct LoadOpToBlockIOConversion
13421349 if (!usePackedType)
13431350 return failure ();
13441351
1345- if (oneMatrixPerLoadForBT) {
1352+ if (* oneMatrixPerLoadForBT) {
13461353 // Only load 1 operand per inst on row.
13471354 numOperandsPer2DLoadM = 1 ;
13481355 tileHeight = elemsPerDPASInst[threadOrder[rank - 2 ]];
@@ -1391,7 +1398,7 @@ struct LoadOpToBlockIOConversion
13911398 tileLayout *= LinearLayout::identity1D (numOperandsOuterDimPerLoad,
13921399 kIteration , dimOuterStr);
13931400 tileLayout *=
1394- LinearLayout::identity1D (isTransposeRequired && oneMatrixPerLoadForBT
1401+ LinearLayout::identity1D (isTransposeRequired && * oneMatrixPerLoadForBT
13951402 ? 1
13961403 : numOperandsInnerDimPerLoad,
13971404 kIteration , dimInnerStr);
@@ -2466,7 +2473,6 @@ struct LoadOpToBlockIOConversion
24662473 }
24672474
24682475private:
2469- bool oneMatrixPerLoadForBT;
24702476 bool useTileLoadLinearLayout;
24712477};
24722478
// Registers the load/store/atomic/prefetch lowering patterns into `patterns`
// (shown as a diff). The change removes the `oneMatrixPerLoadForBT` parameter
// from the signature and from the LoadOpToBlockIOConversion registration
// (lines marked `-`); `useTileLoadLinearLayout` is still forwarded.
@@ -3498,15 +3504,14 @@ void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
34983504 LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
34993505 RewritePatternSet &patterns,
35003506 const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
3501- PatternBenefit benefit, bool oneMatrixPerLoadForBT,
3502- bool useTileLoadLinearLayout) {
3507+ PatternBenefit benefit, bool useTileLoadLinearLayout) {
// The generic conversions are registered with the caller-supplied `benefit`.
35033508 patterns.add <AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
35043509 StoreOpConversion, PrefetchOpConversion>(
35053510 typeConverter, targetInfo, axisInfoAnalysis, benefit);
35063511 // BlockIO is more efficient than gather load or scatter store.
// Both block-IO patterns are therefore registered with `benefit + 2`, i.e. a
// higher benefit value than the generic conversions above.
35073512 patterns.add <LoadOpToBlockIOConversion>(
35083513 typeConverter, targetInfo, axisInfoAnalysis, benefit.getBenefit () + 2 ,
3509- oneMatrixPerLoadForBT, useTileLoadLinearLayout);
3514+ useTileLoadLinearLayout);
35103515 patterns.add <StoreOpToBlockIOConversion>(
35113516 typeConverter, targetInfo, axisInfoAnalysis, benefit.getBenefit () + 2 );
35123517}
0 commit comments