Skip to content

Commit 85aa5f8

Browse files
[mlir][SCF][GPU] Add DeviceMaskingAttrInterface support to scf::ForallOp and use it to implement warp specialization.
This revision adds DeviceMaskingAttrInterface and extends DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface and DeviceMaskingAttrInterface. The first implementation is in the form of a GPUMappingMaskAttr, which can be additionally passed to the scf.forall.mapping attribute to specify a mask on compute resources that should be active. Support is added to GPUTransformOps to take advantage of this information and lower to block/warpgroup/warp/thread specialization when mapped to linear ids. Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
1 parent c88aee7 commit 85aa5f8

File tree

12 files changed

+444
-59
lines changed

12 files changed

+444
-59
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,24 @@ def GPULaneMappingAttr
252252
}];
253253
}
254254

255+
def GPUMappingMaskAttr : GPU_Attr<"GPUMappingMask", "mask", [
256+
DeclareAttrInterfaceMethods<DeviceMaskingAttrInterface> ] > {
257+
let parameters = (ins "uint64_t":$mask);
258+
let assemblyFormat = "`<` params `>`";
259+
let description = [{
260+
Attribute describing how to filter the processing units that a
261+
region is mapped to.
262+
263+
In the first implementation the masking is a bitfield that specifies for
264+
each processing unit whether it is active or not.
265+
266+
In the future, we may want to implement this as a symbol to refer to
267+
dynamically defined values.
268+
269+
Extending op semantics with an operand is deemed too intrusive at this time.
270+
}];
271+
}
272+
255273
def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
256274
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
257275
let parameters = (ins

mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ struct GpuIdBuilder {
7878
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
7979
/// used for indexing rewrites as well as 1D sizes for predicate generation.
8080
struct GpuBlockIdBuilder : public GpuIdBuilder {
81-
GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
81+
GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
82+
DeviceMaskingAttrInterface mask = nullptr);
8283
};
8384

8485
/// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
@@ -88,7 +89,8 @@ struct GpuBlockIdBuilder : public GpuIdBuilder {
8889
/// used for indexing rewrites as well as 1D sizes for predicate generation.
8990
struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
9091
GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
91-
bool useLinearMapping = false);
92+
bool useLinearMapping = false,
93+
DeviceMaskingAttrInterface mask = nullptr);
9294
int64_t warpSize = 32;
9395
/// In the future this may be configured by the transformation.
9496
static constexpr int64_t kNumWarpsPerGroup = 4;
@@ -101,7 +103,8 @@ struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
101103
/// used for indexing rewrites as well as 1D sizes for predicate generation.
102104
struct GpuWarpIdBuilder : public GpuIdBuilder {
103105
GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
104-
bool useLinearMapping = false);
106+
bool useLinearMapping = false,
107+
DeviceMaskingAttrInterface mask = nullptr);
105108
int64_t warpSize = 32;
106109
};
107110

@@ -111,15 +114,17 @@ struct GpuWarpIdBuilder : public GpuIdBuilder {
111114
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
112115
/// used for indexing rewrites as well as 1D sizes for predicate generation.
113116
struct GpuThreadIdBuilder : public GpuIdBuilder {
114-
GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
117+
GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
118+
DeviceMaskingAttrInterface mask = nullptr);
115119
};
116120

117121
/// Builder for lane id.
118122
/// The `idBuilder` method returns nD values used for indexing rewrites as well
119123
/// as 1D sizes for predicate generation.
120124
/// This `useLinearMapping` case is the only supported case.
121125
struct GpuLaneIdBuilder : public GpuIdBuilder {
122-
GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused);
126+
GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused,
127+
DeviceMaskingAttrInterface mask = nullptr);
123128
int64_t warpSize = 32;
124129
};
125130

mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,51 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
6060
];
6161
}
6262

63+
def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
64+
let cppNamespace = "::mlir";
65+
let description = [{
66+
Attribute interface describing how to filter the processing units that a
67+
region is mapped to.
68+
69+
A popcount can be applied to determine the logical linear index that a
70+
physical processing unit is responsible for.
71+
}];
72+
73+
let methods = [
74+
InterfaceMethod<
75+
/*desc=*/[{
76+
Return the logical active id for a given physical id.
77+
Expects a physicalLinearMappingId of I64Type.
78+
}],
79+
/*retTy=*/"Value",
80+
/*methodName=*/"getLogicalLinearMappingId",
81+
/*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
82+
>,
83+
InterfaceMethod<
84+
/*desc=*/[{
85+
Return the dynamic condition determining whether a given physical id is
86+
active under the mask.
87+
Expects a physicalLinearMappingId of I64Type.
88+
}],
89+
/*retTy=*/"Value",
90+
/*methodName=*/"getIsActiveIdPredicate",
91+
/*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
92+
>,
93+
InterfaceMethod<
94+
/*desc=*/[{
95+
Return the maximal number of physical ids supported.
96+
This is to account for temporary implementation limitations (e.g. i64)
97+
and fail gracefully with actionable error messages.
98+
}],
99+
/*retTy=*/"int64_t",
100+
/*methodName=*/"getMaxNumPhysicalIds",
101+
/*args=*/(ins)
102+
>,
103+
];
104+
}
105+
63106
def DeviceMappingArrayAttr :
64-
TypedArrayAttrBase<DeviceMappingAttrInterface,
107+
TypedArrayAttrBase<AnyAttrOf<[DeviceMappingAttrInterface, DeviceMaskingAttrInterface]>,
65108
"Device Mapping array attribute"> { }
66109

67110
#endif // MLIR_DEVICEMAPPINGINTERFACE

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,18 @@ def ForallOp : SCF_Op<"forall", [
611611
/// Returns operations within scf.forall.in_parallel whose destination
612612
/// operand is the block argument `bbArg`.
613613
SmallVector<Operation*> getCombiningOps(BlockArgument bbArg);
614+
615+
/// Returns the subset of DeviceMappingArrayAttrs of type
616+
/// DeviceMappingAttrInterface.
617+
SmallVector<DeviceMappingAttrInterface> getDeviceMappingAttrs();
618+
619+
/// Returns the DeviceMaskingAttrInterface in the mapping (at most one is allowed).
620+
/// If more than one DeviceMaskingAttrInterface is specified, returns
621+
/// failure. If no mapping is present, returns nullptr.
622+
FailureOr<DeviceMaskingAttrInterface> getDeviceMaskingAttr();
623+
624+
/// Returns true if the mapping specified for this forall op is linear.
625+
bool usesLinearMapping();
614626
}];
615627
}
616628

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRGPUDialect
2020
MLIRFunctionInterfaces
2121
MLIRInferIntRangeInterface
2222
MLIRIR
23+
MLIRMathDialect
2324
MLIRMemRefDialect
2425
MLIRSideEffectInterfaces
2526
MLIRSupport

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
17+
#include "mlir/Dialect/Math/IR/Math.h"
1718
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1819
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/Builders.h"
@@ -120,6 +121,50 @@ int64_t GPULaneMappingAttr::getRelativeIndex() const {
120121
: getMappingId();
121122
}
122123

124+
int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
125+
126+
/// 8 4 0
127+
/// Example mask : 0 0 0 1 1 0 1 0 0
128+
///
129+
/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
130+
/// Logical id for e.g. 5 (2) constructs filter ((1 << 5) - 1).
131+
///
132+
/// Example mask : 0 0 0 1 1 0 1 0 0
133+
/// Example filter: 0 0 0 0 1 1 1 1 1
134+
/// Intersection : 0 0 0 0 1 0 1 0 0
135+
/// PopCnt : 2
136+
Value GPUMappingMaskAttr::getLogicalLinearMappingId(
137+
OpBuilder &b, Value physicalLinearMappingId) const {
138+
Location loc = physicalLinearMappingId.getLoc();
139+
Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
140+
Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
141+
Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
142+
filter = b.create<arith::SubIOp>(loc, filter, one);
143+
Value filteredId = b.create<arith::AndIOp>(loc, mask, filter);
144+
return b.create<math::CtPopOp>(loc, filteredId);
145+
}
146+
147+
/// 8 4 0
148+
/// Example mask : 0 0 0 1 1 0 1 0 0
149+
///
150+
/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
151+
/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
152+
///
153+
/// Example mask : 0 0 0 1 1 0 1 0 0
154+
/// Example filter: 0 0 0 1 0 0 0 0 0
155+
/// Intersection : 0 0 0 1 0 0 0 0 0
156+
/// Cmp : 1
157+
Value GPUMappingMaskAttr::getIsActiveIdPredicate(
158+
OpBuilder &b, Value physicalLinearMappingId) const {
159+
Location loc = physicalLinearMappingId.getLoc();
160+
Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
161+
Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
162+
Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
163+
Value filtered = b.create<arith::AndIOp>(loc, mask, filter);
164+
Value zero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
165+
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, filtered, zero);
166+
}
167+
123168
int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
124169
return static_cast<int64_t>(getAddressSpace());
125170
}

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -351,16 +351,25 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
351351
seen.insert(map);
352352
}
353353

354-
auto isLinear = [](Attribute a) {
355-
return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
354+
auto isLinear = [](DeviceMappingAttrInterface attr) {
355+
return attr.isLinearMapping();
356356
};
357-
if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
358-
!llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
357+
if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
358+
!llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
359359
return definiteFailureHelper(
360360
transformOp, forallOp,
361361
"cannot mix linear and non-linear mapping modes");
362362
}
363363

364+
FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
365+
forallOp.getDeviceMaskingAttr();
366+
if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
367+
!forallOp.usesLinearMapping()) {
368+
return definiteFailureHelper(
369+
transformOp, forallOp,
370+
"device masking is only available in linear mapping mode");
371+
}
372+
364373
return DiagnosedSilenceableFailure::success();
365374
}
366375

@@ -381,9 +390,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
381390
if (forallOp.getNumResults() > 0)
382391
return definiteFailureHelper(transformOp, forallOp,
383392
"only bufferized scf.forall can be mapped");
384-
bool useLinearMapping = cast<DeviceMappingAttrInterface>(
385-
forallOp.getMapping()->getValue().front())
386-
.isLinearMapping();
393+
bool useLinearMapping = forallOp.usesLinearMapping();
387394
// TODO: This would be more natural with support for Optional<EnumParameter>
388395
// in GPUDeviceMappingAttr.
389396
int64_t maxNumMappingsSupported =
@@ -436,8 +443,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
436443
assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
437444
"requires statically sized, normalized forall op");
438445
SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
446+
SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec =
447+
forallOp.getDeviceMappingAttrs();
439448
SetVector<Attribute> forallMappingAttrs;
440-
forallMappingAttrs.insert_range(forallOp.getMapping()->getValue());
449+
forallMappingAttrs.insert_range(forallMappingAttrsVec);
441450
auto comparator = [](Attribute a, Attribute b) -> bool {
442451
return cast<DeviceMappingAttrInterface>(a).getMappingId() <
443452
cast<DeviceMappingAttrInterface>(b).getMappingId();
@@ -682,12 +691,17 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
682691

683692
// The BlockIdBuilder adapts to whatever is thrown at it.
684693
bool useLinearMapping = false;
685-
if (topLevelForallOp.getMapping()) {
686-
auto mappingAttr = cast<DeviceMappingAttrInterface>(
687-
topLevelForallOp.getMapping()->getValue().front());
688-
useLinearMapping = mappingAttr.isLinearMapping();
689-
}
690-
GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
694+
if (topLevelForallOp.getMapping())
695+
useLinearMapping = topLevelForallOp.usesLinearMapping();
696+
697+
FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
698+
topLevelForallOp.getDeviceMaskingAttr();
699+
assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
700+
assert((!*maybeMaskingAttr || useLinearMapping) &&
701+
"masking requires linear mapping");
702+
703+
GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
704+
*maybeMaskingAttr);
691705

692706
diag = mlir::transform::gpu::mapForallToBlocksImpl(
693707
rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
@@ -744,8 +758,7 @@ static DiagnosedSilenceableFailure
744758
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
745759
scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
746760
int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
747-
auto mappingAttr = cast<DeviceMappingAttrInterface>(
748-
forallOp.getMapping()->getValue().front());
761+
auto mappingAttr = forallOp.getDeviceMappingAttrs().front();
749762
bool useLinearMapping = mappingAttr.isLinearMapping();
750763

751764
// Sanity checks that may result in runtime verification errors.
@@ -768,21 +781,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
768781
if (!diag.succeeded())
769782
return diag;
770783

784+
FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
785+
forallOp.getDeviceMaskingAttr();
786+
assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
787+
assert((!*maybeMaskingAttr || useLinearMapping) &&
788+
"masking requires linear mapping");
789+
771790
// Start mapping.
772791
MLIRContext *ctx = forallOp.getContext();
773792
gpuIdBuilder =
774793
TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
775794
.Case([&](GPUWarpgroupMappingAttr) {
776-
return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
795+
return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
796+
*maybeMaskingAttr);
777797
})
778798
.Case([&](GPUWarpMappingAttr) {
779-
return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
799+
return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
800+
*maybeMaskingAttr);
780801
})
781802
.Case([&](GPUThreadMappingAttr) {
782-
return GpuThreadIdBuilder(ctx, useLinearMapping);
803+
return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
783804
})
784805
.Case([&](GPULaneMappingAttr) {
785-
return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping);
806+
return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
807+
*maybeMaskingAttr);
786808
})
787809
.Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
788810
llvm_unreachable("unknown mapping attribute");

0 commit comments

Comments
 (0)