@@ -115,6 +115,16 @@ bool isROCmBackend(IREE::GPU::TargetAttr target) {
   return target.getArch().starts_with("gfx");
 }

+static bool needsLoweringConfigPropagation(
+    IREE::Codegen::DispatchLoweringPassPipeline pipeline) {
+  using Pipeline = IREE::Codegen::DispatchLoweringPassPipeline;
+  // Pipelines that do not need propagation of lowering config.
+  Pipeline supportedPipelines[] = {Pipeline::LLVMGPUTileAndFuse,
+                                   Pipeline::LLVMGPUVectorDistribute,
+                                   Pipeline::LLVMGPUPadAndVectorDistribute};
+  return !llvm::is_contained(supportedPipelines, pipeline);
+}
+
 //===----------------------------------------------------------------------===//
 // Matmul Configuration Helpers
 //===----------------------------------------------------------------------===//
@@ -339,6 +349,7 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
       schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};

   SmallVector<int64_t> workgroupTileSizes(op.getNumLoops(), 0);
+  SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   // Tile all batch dimensions with unit size.
   for (int64_t batch : convolutionDims->batch) {
     workgroupTileSizes[batch] = 1;
@@ -351,51 +362,58 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
     workgroupTileSizes[oc] = 1;
   }
   for (int64_t ic : llvm::drop_end(convolutionDims->inputChannel)) {
-    workgroupTileSizes[ic] = 1;
+    reductionTileSizes[ic] = 1;
   }
   // Compute the M/N dimension tile size by multiplying subgroup information.
   workgroupTileSizes[mDim] =
       schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
   workgroupTileSizes[nDim] =
       schedule->nWarpCount * schedule->nTileCount * schedule->nSize;

-  // Follow the LLVMGPU convention of keeping all of the tile sizes in one list.
-  workgroupTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;

   // Tile all filter loop dimensions to 1.
   for (int64_t filterDim : convolutionDims->filterLoop) {
-    workgroupTileSizes[filterDim] = 1;
+    reductionTileSizes[filterDim] = 1;
   }

-  TileSizesListType tileSizes;
-  tileSizes.push_back(workgroupTileSizes);
+  MLIRContext *context = op.getContext();
+  Builder b(context);
+  SmallVector<NamedAttribute, 2> attrs;
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "reduction"),
+                     b.getI64ArrayAttr(reductionTileSizes));
+
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

   // Attach the MMA schedule as an attribute to the entry point export function
   // for later access in the pipeline.
-  MLIRContext *context = op.getContext();
+  SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
       context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
       schedule->nWarpCount);
-  SmallVector<NamedAttribute, 1> attrs;
-  attrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr);
+  pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
+                             scheduleAttr);

   // Prefetch shared memory if requested.
   if (clLLVMGPUEnablePrefetch) {
     auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
         context, /*prefetchSharedMemory=*/true,
         /*no_reduce_shared_memory_bank_conflicts=*/false,
         /*reorder_workgroups_strategy=*/std::nullopt);
-    attrs.emplace_back(
+    pipelineAttrs.emplace_back(
         StringAttr::get(context,
                         IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName()),
         pipelineOptions);
   }

-  auto configDict = DictionaryAttr::get(context, attrs);
+  auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);

   return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUVectorDistribute,
-      workgroupSize, targetSubgroupSize, configDict);
+      entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
+      workgroupSize, targetSubgroupSize, pipelineConfig);
 }

 [[maybe_unused]] static void
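For orientation, a sketch of what the split tiling above produces for a hypothetical NHWC convolution with loop order (n, oh, ow, oc, fh, fw, ic); the sizes are made up, but the zero/nonzero pattern is the point: parallel loops are tiled at the workgroup level, filter and input-channel loops at the reduction level:

    // Hypothetical sizes, illustrative only.
    SmallVector<int64_t> workgroupTileSizes = {1, 1, 32, 64, 0, 0, 0};
    SmallVector<int64_t> reductionTileSizes = {0, 0, 0, 0, 1, 1, 32};
    // The attribute built from these prints on the op roughly as:
    //   #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 1, 1, 32],
    //                              workgroup = [1, 1, 32, 64, 0, 0, 0]}>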
@@ -573,6 +591,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
       schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};

   SmallVector<int64_t> workgroupTileSizes(op.getNumLoops(), 0);
+  SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   // Tile all batch dimensions with unit size.
   for (int64_t batch : contractionDims->batch) {
     workgroupTileSizes[batch] = 1;
@@ -587,7 +606,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
     workgroupTileSizes[n] = 1;
   }
   for (int64_t k : llvm::drop_end(contractionDims->k)) {
-    workgroupTileSizes[k] = 1;
+    reductionTileSizes[k] = 1;
   }

   // Compute the M/N dimension tile size by multiplying subgroup information.
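A note on the `llvm::drop_end` loops (here and in the convolution and attention variants): they visit every reduction dimension except the innermost one, which receives the schedule-derived tile size instead. A sketch, assuming a contraction with two K dimensions:

    // Illustrative: contractionDims->k == {k0, k1}, k1 innermost.
    // drop_end visits only k0, so k0 is tiled to 1, while k1 gets
    //   reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
    // in the following hunk.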
@@ -596,41 +615,50 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   workgroupTileSizes[nDim] =
       schedule->nWarpCount * schedule->nTileCount * schedule->nSize;

-  // Follow the LLVMGPU convention of keeping all of the tile sizes in one list.
-  workgroupTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;

   LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(),
                                        *contractionDims, workgroupTileSizes));
+  LLVM_DEBUG(debugPrintContractionInfo("Reduction tile sizes", op.getNumLoops(),
+                                       *contractionDims, reductionTileSizes));

-  TileSizesListType tileSizes;
-  tileSizes.push_back(workgroupTileSizes);
+  MLIRContext *context = op.getContext();
+  Builder b(context);
+  SmallVector<NamedAttribute, 2> attrs;
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "reduction"),
+                     b.getI64ArrayAttr(reductionTileSizes));
+
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

   // Attach the MMA schedule as an attribute to the entry point export function
   // for later access in the pipeline.
-  MLIRContext *context = op.getContext();
+  SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
       context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
       schedule->nWarpCount);
-  SmallVector<NamedAttribute, 1> attrs;
-  attrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr);
+  pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
+                             scheduleAttr);

   // Prefetch shared memory if requested.
   if (clLLVMGPUEnablePrefetch) {
     auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
         context, /*prefetchSharedMemory=*/true,
         /*no_reduce_shared_memory_bank_conflicts=*/false,
         /*reorder_workgroups_strategy=*/std::nullopt);
-    attrs.emplace_back(
+    pipelineAttrs.emplace_back(
         StringAttr::get(context,
                         IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName()),
         pipelineOptions);
   }

-  auto configDict = DictionaryAttr::get(context, attrs);
+  auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);

-  return setOpConfigAndEntryPointFnTranslation(entryPoint, op, tileSizes,
-                                               pipeline, workgroupSize,
-                                               targetSubgroupSize, configDict);
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPoint, op, loweringConfig, pipeline, workgroupSize,
+      targetSubgroupSize, pipelineConfig);
 }

 static LogicalResult
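To make the schedule arithmetic concrete, a worked example with hypothetical numbers (real values come from the deduced MMA schedule):

    // Hypothetical 16x16x16 MMA intrinsic, 2x2 warps, 4 M-tiles,
    // 2 N-tiles, and 2 K-tiles per subgroup:
    //   workgroupTileSizes[mDim] = 2 * 4 * 16 = 128
    //   workgroupTileSizes[nDim] = 2 * 2 * 16 = 64
    //   reductionTileSizes[kDim] = 2 * 16 = 32
    // Each workgroup then owns a 128x64 output tile and steps the K
    // loop 32 elements per iteration.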
@@ -712,8 +740,6 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,

   LDBG("Attention Vector Distribution Config");

-  auto pipeline = CodeGenPipeline::LLVMGPUVectorDistribute;
-
   // Infer if Q, K and V are transposed to help generate better schedule.
   bool transposedQ =
       k1Dim != llvm::cast<AffineDimExpr>(op.getQueryMap().getResults().back())
@@ -765,6 +791,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
       schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};

   SmallVector<int64_t> workgroupTileSizes(opInfo.getDomainRank(), 0);
+  SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   // Tile all batch dimensions with unit size.
   for (int64_t batch : opInfo.getBatchDims()) {
     workgroupTileSizes[batch] = 1;
@@ -780,7 +807,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
     workgroupTileSizes[n] = 1;
   }
   for (int64_t k2 : llvm::drop_end(opInfo.getK2Dims())) {
-    workgroupTileSizes[k2] = 1;
+    reductionTileSizes[k2] = 1;
   }

   // Compute the M/N dimension tile size by multiplying subgroup information.
@@ -789,29 +816,36 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   workgroupTileSizes[nDim] =
       schedule->nWarpCount * schedule->nTileCount * schedule->nSize;

-  // Follow the LLVMGPU convention of keeping all of the tile sizes in one list.
-  workgroupTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize;

-  TileSizesListType tileSizes;
-  tileSizes.push_back(workgroupTileSizes);
+  MLIRContext *context = op.getContext();
+  Builder b(context);
+  SmallVector<NamedAttribute, 2> attrs;
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "reduction"),
+                     b.getI64ArrayAttr(reductionTileSizes));
+
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

   // Attach the MMA schedule as an attribute to the entry point export function
   // for later access in the pipeline.
-  MLIRContext *context = op.getContext();
+  SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
       context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
       schedule->nWarpCount);
-  SmallVector<NamedAttribute, 1> attrs;
-  attrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr);
-  auto configDict = DictionaryAttr::get(context, attrs);
+  pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
+                             scheduleAttr);

   // TODO: We do not turn prefetching on even when requested by the
   // prefetching flag because there is a shared memory allocation between
   // the two matmuls, which the prefetching pass cannot understand.

-  return setOpConfigAndEntryPointFnTranslation(entryPoint, op, tileSizes,
-                                               pipeline, workgroupSize,
-                                               targetSubgroupSize, configDict);
+  auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);
+
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
+      workgroupSize, targetSubgroupSize, pipelineConfig);
 }

 static LogicalResult
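The end state for all three configs, sketched with hypothetical values: tile sizes move onto the op's lowering config, while the MMA schedule (and, where enabled, pipeline options) stay in the translation info's config dictionary:

    // Rough shape of the resulting IR (values hypothetical):
    //   #iree_codegen.translation_info<LLVMGPUVectorDistribute
    //       workgroup_size = [128, 2, 1] subgroup_size = 64,
    //       {mma_schedule = #iree_gpu.mma_schedule<...>}>
    // on the entry point, and
    //   #iree_gpu.lowering_config<{reduction = [...], workgroup = [...]}>
    // on the attention/matmul/conv op itself.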
@@ -2108,10 +2142,9 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
   SmallVector<Operation *> computeOps = getComputeOps(funcOp);
   if (IREE::Codegen::TranslationInfoAttr translationInfo =
           getTranslationInfo(funcOp)) {
-    // Currently ROCDL requires propagation of user lowering configs for
-    // all pipelines except TileAndFuse.
-    if (translationInfo.getDispatchLoweringPassPipeline() !=
-        IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
+    // Some ROCDL pipelines still require propagation of user lowering configs.
+    if (needsLoweringConfigPropagation(
+            translationInfo.getDispatchLoweringPassPipeline())) {
       for (auto op : computeOps) {
         if (getLoweringConfig(op)) {
           propagateLoweringConfig(op, computeOps);
@@ -2165,10 +2198,9 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {

   if (IREE::Codegen::TranslationInfoAttr translationInfo =
           getTranslationInfo(funcOp)) {
-    // Currently ROCDL requires propagation of user lowering configs for
-    // all pipelines except TileAndFuse.
-    if (translationInfo.getDispatchLoweringPassPipeline() ==
-        IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
+    // Some ROCDL pipelines still require propagation of user lowering configs.
+    if (!needsLoweringConfigPropagation(
+            translationInfo.getDispatchLoweringPassPipeline())) {
       return success();
     }
   }
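Taken together, the two guarded blocks implement complementary early-outs around the new helper:

    // pipeline in {TileAndFuse, VectorDistribute, PadAndVectorDistribute}:
    //   needsLoweringConfigPropagation() == false
    //   -> skip user-config propagation; the second block returns
    //      success() early since the pipeline sets its own configs.
    // any other pipeline:
    //   -> propagate the user lowering config across computeOps first.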