@@ -351,16 +351,25 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
351351 seen.insert (map);
352352 }
353353
354- auto isLinear = [](Attribute a ) {
355- return cast<DeviceMappingAttrInterface>(a) .isLinearMapping ();
354+ auto isLinear = [](DeviceMappingAttrInterface attr ) {
355+ return attr .isLinearMapping ();
356356 };
357- if (llvm::any_of (forallOp.getMapping ()-> getValue (), isLinear) &&
358- !llvm::all_of (forallOp.getMapping ()-> getValue (), isLinear)) {
357+ if (llvm::any_of (forallOp.getDeviceMappingAttrs (), isLinear) &&
358+ !llvm::all_of (forallOp.getDeviceMappingAttrs (), isLinear)) {
359359 return definiteFailureHelper (
360360 transformOp, forallOp,
361361 " cannot mix linear and non-linear mapping modes" );
362362 }
363363
364+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
365+ forallOp.getDeviceMaskingAttr ();
366+ if (succeeded (maybeMaskingAttr) && *maybeMaskingAttr &&
367+ !forallOp.usesLinearMapping ()) {
368+ return definiteFailureHelper (
369+ transformOp, forallOp,
370+ " device masking is only available in linear mapping mode" );
371+ }
372+
364373 return DiagnosedSilenceableFailure::success ();
365374}
366375
@@ -381,9 +390,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
381390 if (forallOp.getNumResults () > 0 )
382391 return definiteFailureHelper (transformOp, forallOp,
383392 " only bufferized scf.forall can be mapped" );
384- bool useLinearMapping = cast<DeviceMappingAttrInterface>(
385- forallOp.getMapping ()->getValue ().front ())
386- .isLinearMapping ();
393+ bool useLinearMapping = forallOp.usesLinearMapping ();
387394 // TODO: This would be more natural with support for Optional<EnumParameter>
388395 // in GPUDeviceMappingAttr.
389396 int64_t maxNumMappingsSupported =
@@ -436,8 +443,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
436443 assert (forallOp.isNormalized () && numParallelIterations.has_value () &&
437444 " requires statically sized, normalized forall op" );
438445 SmallVector<int64_t > tmpMappingSizes = numParallelIterations.value ();
446+ SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec =
447+ forallOp.getDeviceMappingAttrs ();
439448 SetVector<Attribute> forallMappingAttrs;
440- forallMappingAttrs.insert_range (forallOp. getMapping ()-> getValue () );
449+ forallMappingAttrs.insert_range (forallMappingAttrsVec );
441450 auto comparator = [](Attribute a, Attribute b) -> bool {
442451 return cast<DeviceMappingAttrInterface>(a).getMappingId () <
443452 cast<DeviceMappingAttrInterface>(b).getMappingId ();
@@ -682,12 +691,17 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
682691
683692 // The BlockIdBuilder adapts to whatever is thrown at it.
684693 bool useLinearMapping = false ;
685- if (topLevelForallOp.getMapping ()) {
686- auto mappingAttr = cast<DeviceMappingAttrInterface>(
687- topLevelForallOp.getMapping ()->getValue ().front ());
688- useLinearMapping = mappingAttr.isLinearMapping ();
689- }
690- GpuBlockIdBuilder gpuBlockIdBuilder (getContext (), useLinearMapping);
694+ if (topLevelForallOp.getMapping ())
695+ useLinearMapping = topLevelForallOp.usesLinearMapping ();
696+
697+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
698+ topLevelForallOp.getDeviceMaskingAttr ();
699+ assert (succeeded (maybeMaskingAttr) && " unexpected failed maybeMaskingAttr" );
700+ assert ((!*maybeMaskingAttr || useLinearMapping) &&
701+ " masking requires linear mapping" );
702+
703+ GpuBlockIdBuilder gpuBlockIdBuilder (getContext (), useLinearMapping,
704+ *maybeMaskingAttr);
691705
692706 diag = mlir::transform::gpu::mapForallToBlocksImpl (
693707 rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
@@ -744,8 +758,7 @@ static DiagnosedSilenceableFailure
744758getThreadIdBuilder (std::optional<TransformOpInterface> transformOp,
745759 scf::ForallOp forallOp, ArrayRef<int64_t > blockSizes,
746760 int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
747- auto mappingAttr = cast<DeviceMappingAttrInterface>(
748- forallOp.getMapping ()->getValue ().front ());
761+ auto mappingAttr = forallOp.getDeviceMappingAttrs ().front ();
749762 bool useLinearMapping = mappingAttr.isLinearMapping ();
750763
751764 // Sanity checks that may result in runtime verification errors.
@@ -768,21 +781,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
768781 if (!diag.succeeded ())
769782 return diag;
770783
784+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
785+ forallOp.getDeviceMaskingAttr ();
786+ assert (succeeded (maybeMaskingAttr) && " unexpected failed maybeMaskingAttr" );
787+ assert ((!*maybeMaskingAttr || useLinearMapping) &&
788+ " masking requires linear mapping" );
789+
771790 // Start mapping.
772791 MLIRContext *ctx = forallOp.getContext ();
773792 gpuIdBuilder =
774793 TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
775794 .Case ([&](GPUWarpgroupMappingAttr) {
776- return GpuWarpgroupIdBuilder (ctx, warpSize, useLinearMapping);
795+ return GpuWarpgroupIdBuilder (ctx, warpSize, useLinearMapping,
796+ *maybeMaskingAttr);
777797 })
778798 .Case ([&](GPUWarpMappingAttr) {
779- return GpuWarpIdBuilder (ctx, warpSize, useLinearMapping);
799+ return GpuWarpIdBuilder (ctx, warpSize, useLinearMapping,
800+ *maybeMaskingAttr);
780801 })
781802 .Case ([&](GPUThreadMappingAttr) {
782- return GpuThreadIdBuilder (ctx, useLinearMapping);
803+ return GpuThreadIdBuilder (ctx, useLinearMapping, *maybeMaskingAttr );
783804 })
784805 .Case ([&](GPULaneMappingAttr) {
785- return GpuLaneIdBuilder (ctx, warpSize, useLinearMapping);
806+ return GpuLaneIdBuilder (ctx, warpSize, useLinearMapping,
807+ *maybeMaskingAttr);
786808 })
787809 .Default ([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
788810 llvm_unreachable (" unknown mapping attribute" );
0 commit comments