1212#include " mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
1313#include " mlir/Conversion/LLVMCommon/TypeConverter.h"
1414#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15- #include " mlir/Dialect/Affine/IR/AffineOps.h"
1615#include " mlir/Dialect/Arith/IR/Arith.h"
17- #include " mlir/Dialect/Func/IR/FuncOps.h"
1816#include " mlir/Dialect/GPU/IR/GPUDialect.h"
1917#include " mlir/Dialect/GPU/TransformOps/Utils.h"
2018#include " mlir/Dialect/GPU/Transforms/Passes.h"
@@ -351,16 +349,25 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
351349 seen.insert (map);
352350 }
353351
354- auto isLinear = [](Attribute a ) {
355- return cast<DeviceMappingAttrInterface>(a) .isLinearMapping ();
352+ auto isLinear = [](DeviceMappingAttrInterface attr ) {
353+ return attr .isLinearMapping ();
356354 };
357- if (llvm::any_of (forallOp.getMapping ()-> getValue (), isLinear) &&
358- !llvm::all_of (forallOp.getMapping ()-> getValue (), isLinear)) {
355+ if (llvm::any_of (forallOp.getDeviceMappingAttrs (), isLinear) &&
356+ !llvm::all_of (forallOp.getDeviceMappingAttrs (), isLinear)) {
359357 return definiteFailureHelper (
360358 transformOp, forallOp,
361359 " cannot mix linear and non-linear mapping modes" );
362360 }
363361
362+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
363+ forallOp.getDeviceMaskingAttr ();
364+ if (succeeded (maybeMaskingAttr) && *maybeMaskingAttr &&
365+ !forallOp.usesLinearMapping ()) {
366+ return definiteFailureHelper (
367+ transformOp, forallOp,
368+ " device masking is only available in linear mapping mode" );
369+ }
370+
364371 return DiagnosedSilenceableFailure::success ();
365372}
366373
@@ -381,9 +388,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
381388 if (forallOp.getNumResults () > 0 )
382389 return definiteFailureHelper (transformOp, forallOp,
383390 " only bufferized scf.forall can be mapped" );
384- bool useLinearMapping = cast<DeviceMappingAttrInterface>(
385- forallOp.getMapping ()->getValue ().front ())
386- .isLinearMapping ();
391+ bool useLinearMapping = forallOp.usesLinearMapping ();
387392 // TODO: This would be more natural with support for Optional<EnumParameter>
388393 // in GPUDeviceMappingAttr.
389394 int64_t maxNumMappingsSupported =
@@ -436,8 +441,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
436441 assert (forallOp.isNormalized () && numParallelIterations.has_value () &&
437442 " requires statically sized, normalized forall op" );
438443 SmallVector<int64_t > tmpMappingSizes = numParallelIterations.value ();
444+ SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec =
445+ forallOp.getDeviceMappingAttrs ();
439446 SetVector<Attribute> forallMappingAttrs;
440- forallMappingAttrs.insert_range (forallOp. getMapping ()-> getValue () );
447+ forallMappingAttrs.insert_range (forallMappingAttrsVec );
441448 auto comparator = [](Attribute a, Attribute b) -> bool {
442449 return cast<DeviceMappingAttrInterface>(a).getMappingId () <
443450 cast<DeviceMappingAttrInterface>(b).getMappingId ();
@@ -682,12 +689,17 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
682689
683690 // The BlockIdBuilder adapts to whatever is thrown at it.
684691 bool useLinearMapping = false ;
685- if (topLevelForallOp.getMapping ()) {
686- auto mappingAttr = cast<DeviceMappingAttrInterface>(
687- topLevelForallOp.getMapping ()->getValue ().front ());
688- useLinearMapping = mappingAttr.isLinearMapping ();
689- }
690- GpuBlockIdBuilder gpuBlockIdBuilder (getContext (), useLinearMapping);
692+ if (topLevelForallOp.getMapping ())
693+ useLinearMapping = topLevelForallOp.usesLinearMapping ();
694+
695+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
696+ topLevelForallOp.getDeviceMaskingAttr ();
697+ assert (succeeded (maybeMaskingAttr) && " unexpected failed maybeMaskingAttr" );
698+ assert ((!*maybeMaskingAttr || useLinearMapping) &&
699+ " masking requires linear mapping" );
700+
701+ GpuBlockIdBuilder gpuBlockIdBuilder (getContext (), useLinearMapping,
702+ *maybeMaskingAttr);
691703
692704 diag = mlir::transform::gpu::mapForallToBlocksImpl (
693705 rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
@@ -744,8 +756,8 @@ static DiagnosedSilenceableFailure
744756getThreadIdBuilder (std::optional<TransformOpInterface> transformOp,
745757 scf::ForallOp forallOp, ArrayRef<int64_t > blockSizes,
746758 int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
747- auto mappingAttr = cast<DeviceMappingAttrInterface>(
748- forallOp.getMapping ()-> getValue () .front () );
759+ DeviceMappingAttrInterface mappingAttr =
760+ forallOp.getDeviceMappingAttrs () .front ();
749761 bool useLinearMapping = mappingAttr.isLinearMapping ();
750762
751763 // Sanity checks that may result in runtime verification errors.
@@ -768,21 +780,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
768780 if (!diag.succeeded ())
769781 return diag;
770782
783+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
784+ forallOp.getDeviceMaskingAttr ();
785+ assert (succeeded (maybeMaskingAttr) && " unexpected failed maybeMaskingAttr" );
786+ assert ((!*maybeMaskingAttr || useLinearMapping) &&
787+ " masking requires linear mapping" );
788+
771789 // Start mapping.
772790 MLIRContext *ctx = forallOp.getContext ();
773791 gpuIdBuilder =
774792 TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
775793 .Case ([&](GPUWarpgroupMappingAttr) {
776- return GpuWarpgroupIdBuilder (ctx, warpSize, useLinearMapping);
794+ return GpuWarpgroupIdBuilder (ctx, warpSize, useLinearMapping,
795+ *maybeMaskingAttr);
777796 })
778797 .Case ([&](GPUWarpMappingAttr) {
779- return GpuWarpIdBuilder (ctx, warpSize, useLinearMapping);
798+ return GpuWarpIdBuilder (ctx, warpSize, useLinearMapping,
799+ *maybeMaskingAttr);
780800 })
781801 .Case ([&](GPUThreadMappingAttr) {
782- return GpuThreadIdBuilder (ctx, useLinearMapping);
802+ return GpuThreadIdBuilder (ctx, useLinearMapping, *maybeMaskingAttr );
783803 })
784804 .Case ([&](GPULaneMappingAttr) {
785- return GpuLaneIdBuilder (ctx, warpSize, useLinearMapping);
805+ return GpuLaneIdBuilder (ctx, warpSize, useLinearMapping,
806+ *maybeMaskingAttr);
786807 })
787808 .Default ([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
788809 llvm_unreachable (" unknown mapping attribute" );
0 commit comments