@@ -240,8 +240,8 @@ static GemmCutoff computeGemmCutoffsForAI(IREE::GPU::TargetAttr target,
 // problem based on the available mma intrinsics.
 static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     IREE::GPU::TargetAttr target, GPUMatmulShapeType problem,
-    bool transposedLhs, bool transposedRhs, bool mustBeAligned = true,
-    bool doCPromotion = false, bool scaled = false) {
+    bool transposedLhs, bool transposedRhs, bool isGemm,
+    bool mustBeAligned = true, bool doCPromotion = false, bool scaled = false) {
   const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
   SmallVector<GPUIntrinsicType> intrinsics;
   if (scaled) {
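A note on the reordering in this hunk: C++ requires parameters with default arguments to be trailing, so the new non-defaulted `isGemm` flag must slot in before `mustBeAligned`, `doCPromotion`, and `scaled`. A minimal standalone sketch (hypothetical function name, not the IREE declaration):

```cpp
// Defaulted parameters must come last, so a new required flag is inserted
// before them; callers that spell out the defaults keep working positionally.
static bool scheduleSketch(bool transposedLhs, bool transposedRhs, bool isGemm,
                           bool mustBeAligned = true, bool doCPromotion = false,
                           bool scaled = false) {
  return isGemm && mustBeAligned && !doCPromotion && !scaled &&
         !(transposedLhs || transposedRhs);
}

int main() { return scheduleSketch(false, false, /*isGemm=*/true) ? 0 : 1; }
```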
@@ -307,19 +307,40 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     // bestMNTileCountPerSubgroup and small bestKTileCountPerSubgroup to
     // amortize launch/memory costs and maximize throughput.
     problem.gemmSize = GemmSize::LargeGemm;
-    seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
-             /*bestMNTileCountPerSubgroup=*/8,
-             /*bestKTileCountPerSubgroup=*/2,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
-                 inBitWidth};
+    if (isGemm) {
+      seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
+               /*bestMNTileCountPerSubgroup=*/16,
+               /*bestKTileCountPerSubgroup=*/2,
+               /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
+                   inBitWidth};
+    } else {
+      // Favor more subgroups for convolution to help latency hiding from
+      // global loads.
+      seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
+               /*bestMNTileCountPerSubgroup=*/8,
+               /*bestKTileCountPerSubgroup=*/2,
+               /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
+                   inBitWidth};
+    }
   } else {
     // Choose balanced tile shapes. Empirically, medium-AI workloads can favor
     // either small or large tiles depending on kernel details.
     problem.gemmSize = GemmSize::MediumGemm;
-    seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
-             /*bestMNTileCountPerSubgroup=*/4,
-             /*bestKTileCountPerSubgroup=*/4,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
+    if (isGemm) {
+      seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
+               /*bestMNTileCountPerSubgroup=*/8,
+               /*bestKTileCountPerSubgroup=*/4,
+               /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits /
+                   inBitWidth};
+    } else {
+      // Favor more subgroups for convolution to help latency hiding from
+      // global loads.
+      seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
+               /*bestMNTileCountPerSubgroup=*/4,
+               /*bestKTileCountPerSubgroup=*/4,
+               /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits /
+                   inBitWidth};
+    }
   }
   int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
 
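A worked check of the new seeds (a sketch under assumed semantics, not the IREE struct): in both branches the MN tile budget per workgroup is unchanged; the `isGemm` split only redistributes it. For large GEMMs, 4 subgroups × 16 MN tiles = 64, the same total as the convolution branch's 8 × 8; the convolution branch buys more concurrent subgroups to hide global-load latency at the cost of smaller per-subgroup tiles.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the seed fields used above; not the IREE type.
struct SeedsSketch {
  int64_t subgroupsPerWorkgroup;
  int64_t mnTilesPerSubgroup;
};

int main() {
  SeedsSketch largeGemm = {4, 16}; // fewer, fatter subgroups
  SeedsSketch largeConv = {8, 8};  // more subgroups for latency hiding
  SeedsSketch medGemm = {4, 8};
  SeedsSketch medConv = {8, 4};
  // The MN work per workgroup is identical; only the split differs.
  assert(largeGemm.subgroupsPerWorkgroup * largeGemm.mnTilesPerSubgroup ==
         largeConv.subgroupsPerWorkgroup * largeConv.mnTilesPerSubgroup);
  assert(medGemm.subgroupsPerWorkgroup * medGemm.mnTilesPerSubgroup ==
         medConv.subgroupsPerWorkgroup * medConv.mnTilesPerSubgroup);
  return 0;
}
```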
@@ -407,7 +428,7 @@ static FailureOr<std::pair<LoweringConfigAttr, int64_t>>
 getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     SmallVector<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
-    bool scaled,
+    bool isGemm, bool scaled,
     std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap =
         std::nullopt,
     std::optional<linalg::ConvolutionDimensions> convDims = std::nullopt) {
@@ -537,7 +558,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
   bool mustBeAligned = true;
   bool doCPromotion = false;
   std::optional<GPUMMASchedule> schedule = getMmaScheduleFromProblemAndTarget(
-      target, problem, transposedLhs, transposedRhs, /*mustBeAligned*/ true,
+      target, problem, transposedLhs, transposedRhs, isGemm,
+      /*mustBeAligned*/ true,
       /*doCPromotion*/ false, scaled);
 
   // TODO (nirvedhmeshram, qedawkins): The performance with this will be bad if
@@ -549,7 +571,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     mustBeAligned = false;
     doCPromotion = true;
     schedule = getMmaScheduleFromProblemAndTarget(
-        target, problem, transposedLhs, transposedRhs, mustBeAligned,
+        target, problem, transposedLhs, transposedRhs, isGemm, mustBeAligned,
         doCPromotion, scaled);
   }
 
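The hunk above keeps the existing two-step strategy while threading `isGemm` through both attempts: try an aligned schedule first, and only on failure relax alignment and promote the C operand. A self-contained sketch of that control flow (hypothetical types; the real failure predicate lives in the surrounding code):

```cpp
#include <optional>

struct ScheduleSketch { bool aligned; };

// Stand-in for getMmaScheduleFromProblemAndTarget: pretend only the
// unaligned variant fits this problem.
static std::optional<ScheduleSketch> trySchedule(bool isGemm,
                                                 bool mustBeAligned) {
  (void)isGemm;
  if (mustBeAligned)
    return std::nullopt;
  return ScheduleSketch{false};
}

int main() {
  const bool isGemm = true;
  bool mustBeAligned = true;
  bool doCPromotion = false;
  std::optional<ScheduleSketch> schedule = trySchedule(isGemm, mustBeAligned);
  if (!schedule) {
    mustBeAligned = false; // allow unaligned tile shapes
    doCPromotion = true;   // promote C to cope with the resulting padding
    schedule = trySchedule(isGemm, mustBeAligned);
  }
  return (schedule && doCPromotion) ? 0 : 1;
}
```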
@@ -710,7 +732,7 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
           bounds, igemmContractionMaps, igemmOperands, target, useDirectLoad,
-          /*scaled*/ false, convToIgemmDimMap, convDims);
+          /*isGemm=*/false, /*scaled*/ false, convToIgemmDimMap, convDims);
   if (failed(configAndWgSize)) {
     return failure();
   }
@@ -756,7 +778,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
 
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
-          bounds, maps, operands, target, useDirectLoad, /*scaled*/ false);
+          bounds, maps, operands, target, useDirectLoad, /*isGemm=*/true,
+          /*scaled*/ false);
 
   // TODO (muzasyed) : add generalization for scaled and nonscaled versions of
   // matmul lowering.
@@ -765,7 +788,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     // conflicts when dealing with scaled matmuls. For now it is disabled.
     useDirectLoad = true;
     configAndWgSize = getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
-        bounds, maps, operands, target, useDirectLoad, /*scaled*/ true);
+        bounds, maps, operands, target, useDirectLoad, /*isGemm=*/true,
+        /*scaled*/ true);
   }
 
   if (failed(configAndWgSize)) {
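Taken together, the call sites pass `/*isGemm=*/false` from the IGEMM convolution path and `/*isGemm=*/true` from both matmul paths (scaled and non-scaled). A toy driver (hypothetical helper, not the IREE API) condensing which seeds each entry point now requests:

```cpp
#include <cstdio>

// Hypothetical condensation of the seed selection in the hunks above.
static void printSeeds(bool isGemm, bool largeGemm) {
  const int subgroups = isGemm ? 4 : 8;
  const int mnTiles = largeGemm ? (isGemm ? 16 : 8) : (isGemm ? 8 : 4);
  std::printf("%s: subgroups=%d, mnTilesPerSubgroup=%d\n",
              isGemm ? "matmul" : "igemm-conv", subgroups, mnTiles);
}

int main() {
  printSeeds(/*isGemm=*/true, /*largeGemm=*/true);  // setMatmulLoweringConfig
  printSeeds(/*isGemm=*/false, /*largeGemm=*/true); // setIGEMMConvolutionLoweringConfig
  return 0;
}
```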