@@ -52,7 +52,8 @@ constexpr int64_t kPreferredCopyNumBits = 128;
5252
5353LogicalResult setDataTiledMmaInnerTiledLoweringConfig (
5454 IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
55- Operation *op, IREE::Codegen::UKernelDescriptorAttr ukernelConfig) {
55+ Operation *op, IREE::Codegen::UKernelDescriptorAttr ukernelConfig,
56+ std::optional<uint64_t > prefetchNumStages) {
5657 auto multiMmaOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op);
5758 if (!multiMmaOp) {
5859 return failure ();
@@ -109,11 +110,12 @@ LogicalResult setDataTiledMmaInnerTiledLoweringConfig(
109110 DictionaryAttr configDict = b.getDictionaryAttr (attrs);
110111 auto loweringConfig = IREE::GPU::LoweringConfigAttr::get (context, configDict);
111112
112-  // Don't add any special padding or prefetching, since the data-tiled layout
113- // is already what we want.
113+  // By default, don't add any special padding or prefetching, since the
114+ // data-tiled layout is already what we want.
114115 SmallVector<NamedAttribute, 1 > pipelineAttrs;
116+ int64_t prefetchStages = prefetchNumStages.value_or (0 );
115117 auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get (
116- context, /* prefetchNumStages=*/ 0 ,
118+ context, /* prefetchNumStages=*/ prefetchStages ,
117119 /* no_reduce_shared_memory_bank_conflicts=*/ true ,
118120 /* use_igemm_convolution=*/ false ,
119121 /* reorder_workgroups_strategy=*/ std::nullopt );
@@ -1014,7 +1016,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
10141016
10151017LogicalResult setIGEMMConvolutionLoweringConfig (
10161018 IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
1017- Operation *op, bool useDirectLoad, bool padConv) {
1019+ Operation *op, bool useDirectLoad, bool padConv,
1020+ std::optional<uint64_t > prefetchNumStages) {
10181021 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
10191022 if (!linalgOp || !linalg::isaConvolutionOpInterface (linalgOp)) {
10201023 return failure ();
@@ -1100,9 +1103,11 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
11001103 LoweringConfigAttr loweringConfig = configAndWgSize->first ;
11011104
11021105 SmallVector<NamedAttribute, 1 > pipelineAttrs;
1106+ // Default to 2 stages if not specified.
1107+ int64_t prefetchStages = prefetchNumStages.value_or (2 );
11031108 auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get (
11041109 linalgOp->getContext (),
1105- /* prefetchNumStages=*/ useDirectLoad ? 0 : 2 ,
1110+ /* prefetchNumStages=*/ prefetchStages ,
11061111 /* no_reduce_shared_memory_bank_conflicts=*/ useDirectLoad,
11071112 /* use_igemm_convolution=*/ true ,
11081113 /* reorder_workgroups_strategy=*/ std::nullopt );
@@ -1119,9 +1124,11 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
11191124 workgroupSize, targetSubgroupSize, pipelineConfig);
11201125}
11211126
1122- LogicalResult setMatmulLoweringConfig (IREE::GPU::TargetAttr target,
1123- mlir::FunctionOpInterface entryPoint,
1124- Operation *op, bool useDirectLoad) {
1127+ LogicalResult
1128+ setMatmulLoweringConfig (IREE::GPU::TargetAttr target,
1129+ mlir::FunctionOpInterface entryPoint, Operation *op,
1130+ bool useDirectLoad,
1131+ std::optional<uint64_t > prefetchNumStages) {
11251132 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
11261133 if (!linalgOp ||
11271134 (!linalg::isaContractionOpInterface (linalgOp) &&
@@ -1172,9 +1179,11 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
11721179 LoweringConfigAttr loweringConfig = configAndWgSize->first ;
11731180
11741181 SmallVector<NamedAttribute, 1 > pipelineAttrs;
1182+ // Default to 2 stages if not specified.
1183+ int64_t prefetchStages = prefetchNumStages.value_or (2 );
11751184 auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get (
11761185 linalgOp->getContext (),
1177- /* prefetchNumStages=*/ useDirectLoad ? 0 : 2 ,
1186+ /* prefetchNumStages=*/ prefetchStages ,
11781187 /* no_reduce_shared_memory_bank_conflicts=*/ useDirectLoad,
11791188 /* use_igemm_convolution=*/ false ,
11801189 /* reorder_workgroups_strategy=*/ std::nullopt );
@@ -1781,10 +1790,9 @@ LogicalResult setScatterLoweringConfig(IREE::GPU::TargetAttr target,
17811790 {flatWorkgroupSize, 1 , 1 }, flatWorkgroupSize, DictionaryAttr ());
17821791}
17831792
1784- LogicalResult
1785- setDirectConvolutionLoweringConfig (IREE::GPU::TargetAttr target,
1786- mlir::FunctionOpInterface entryPoint,
1787- Operation *op) {
1793+ LogicalResult setDirectConvolutionLoweringConfig (
1794+ IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
1795+ Operation *op, std::optional<uint64_t > prefetchNumStages) {
17881796 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
17891797 if (!linalgOp || !linalg::isaConvolutionOpInterface (linalgOp)) {
17901798 return failure ();
@@ -2029,9 +2037,10 @@ setDirectConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
20292037 auto configDict = DictionaryAttr::get (context, attrs);
20302038 auto loweringConfig = IREE::GPU::LoweringConfigAttr::get (context, configDict);
20312039
2032- // Prefetch shared memory is kept off.
2040+ // By default, prefetch shared memory is kept off.
2041+ int64_t prefetchStages = prefetchNumStages.value_or (0 );
20332042 auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get (
2034- context, /* prefetchNumStages=*/ 0 ,
2043+ context, /* prefetchNumStages=*/ prefetchStages ,
20352044 /* no_reduce_shared_memory_bank_conflicts=*/ false ,
20362045 /* use_igemm_convolution=*/ false ,
20372046 /* reorder_workgroups_strategy=*/ std::nullopt );
0 commit comments