Skip to content

Commit 246761e

Browse files
committed
Improve verification
1 parent 5965b54 commit 246761e

File tree

8 files changed

+77
-47
lines changed

8 files changed

+77
-47
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,12 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
228228
"FailureOr<SmallVector<Value>>",
229229
"delinearizeId",
230230
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
231-
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
231+
InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
232232
assigned to a level identified by linearId. The shape parameter
233233
represents the higher-level problem size. Each level may access
234234
multiple blocks according to round-robin distribution rules.}],
235235
"FailureOr<SmallVector<SmallVector<Value>>>",
236-
"computeDistributedOffsets",
236+
"computeDistributedCoords",
237237
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
238238
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
239239
to some other layout according to given permutation of (0...n-1).}],
@@ -481,12 +481,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
481481
FailureOr<SmallVector<Value>>
482482
delinearizeId(OpBuilder &builder, Location loc, Value linearId);
483483

484-
/// Generates instructions to compute multidimensional offsets for dist units
484+
/// Generates instructions to compute multidimensional coordinates for dist units
485485
/// assigned to a level identified by linearId. The shape parameter
486486
/// represents the higher-level problem size. Each `level` may access
487487
/// multiple blocks according to round-robin distribution rules.
488488
FailureOr<SmallVector<SmallVector<Value>>>
489-
computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
489+
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
490490

491491
/// Check if this is slice of some other layout.
492492
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -645,13 +645,13 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
645645
FailureOr<SmallVector<Value>>
646646
delinearizeId(OpBuilder &builder, Location loc, Value linearId);
647647

648-
/// Generates instructions to compute multidimensional offsets for blocks
648+
/// Generates instructions to compute multidimensional coordinates for blocks
649649
/// assigned to a subgroup identified by linearId. The shape parameter
650650
/// represents the workgroup-level problem size. Each subgroup may access
651651
/// multiple blocks according to round-robin distribution rules.
652652

653653
FailureOr<SmallVector<SmallVector<Value>>>
654-
computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape);
654+
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape);
655655

656656
/// Check if this is slice of some other layout.
657657
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ void XeGPUDialect::initialize() {
4747
// `delinearizedId` is used to identify a 16x32 of a subgroup in each
4848
// distribution unit.
4949
static SmallVector<SmallVector<Value>>
50-
genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
51-
ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
52-
ArrayRef<int64_t> srcShape) {
53-
SmallVector<SmallVector<Value>> offsets;
50+
genCoordinates(OpBuilder &builder, Location loc,
51+
SmallVector<Value> delinearizedId,
52+
ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
53+
ArrayRef<int64_t> srcShape) {
54+
SmallVector<SmallVector<Value>> coordinates;
5455

5556
// A distribution unit must be less than or equal to `srcShape`
5657
SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
@@ -89,9 +90,9 @@ genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
8990
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
9091
});
9192

92-
offsets.push_back(mods);
93+
coordinates.push_back(mods);
9394
}
94-
return offsets;
95+
return coordinates;
9596
}
9697

9798
// Checks if the given shape can be evenly distributed based on the layout
@@ -298,12 +299,12 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
298299
return affine::delinearizeIndex(builder, loc, linearId, dims);
299300
}
300301

301-
/// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
302+
/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
302303
/// instructions for computing multi-dimensional offsets when distributed by
303304
/// LayoutAttr.
304305
FailureOr<SmallVector<SmallVector<Value>>>
305-
LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
306-
Value linearId, ArrayRef<int64_t> shape) {
306+
LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
307+
Value linearId, ArrayRef<int64_t> shape) {
307308
SmallVector<int64_t> layout;
308309
SmallVector<int64_t> subShape;
309310
if (isForWorkgroup()) {
@@ -328,7 +329,7 @@ LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
328329
return failure();
329330
SmallVector<Value> ids = *maybeIds;
330331

331-
return genOffsets(builder, loc, ids, layout, subShape, shape);
332+
return genCoordinates(builder, loc, ids, layout, subShape, shape);
332333
}
333334

334335
//===----------------------------------------------------------------------===//
@@ -388,12 +389,12 @@ SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
388389
return parent.delinearizeId(builder, loc, linearId);
389390
}
390391

391-
// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
392+
// Implements DistributeLayoutAttr::computeDistributedCoords to generate
392393
// instructions for computing multi-dimensional offsets when distributed by
393394
// LayoutAttr.
394395
FailureOr<SmallVector<SmallVector<Value>>>
395-
SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
396-
Value linearId, ArrayRef<int64_t> shape) {
396+
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
397+
Value linearId, ArrayRef<int64_t> shape) {
397398
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
398399
if (!isForWorkgroup())
399400
return failure();
@@ -428,7 +429,7 @@ SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
428429
SmallVector<Value> sgIds =
429430
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
430431

431-
return genOffsets(builder, loc, sgIds, layout, subShape, shape);
432+
return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
432433
}
433434

434435
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
175175

176176
LogicalResult
177177
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
178-
UnitAttr subgroup_block_io,
178+
UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
179179
function_ref<InFlightDiagnostic()> emitError) {
180180

181181
if (!dataTy) {
@@ -191,7 +191,20 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
191191

192192
ArrayRef<int64_t> dataShape = dataTy.getShape();
193193
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
194-
194+
if (subgroup_block_io && layout) {
195+
auto laneData = layout.getEffectiveLaneDataAsInt();
196+
if (!laneData.empty()) {
197+
bool isLaneDataLinear =
198+
std::all_of(laneData.begin(), std::prev(laneData.end()),
199+
[](int x) { return x == 1; });
200+
if (!isLaneDataLinear)
201+
return emitError()
202+
<< "With subgroup_block_io, lane data must be linear.";
203+
if (isLaneDataLinear && laneData.back() != 1)
204+
return emitError()
205+
<< "With subgroup_block_io, lane data must be coalesced.";
206+
}
207+
}
195208
if (dataShape.size() == 2) {
196209
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
197210
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -1102,7 +1115,7 @@ LogicalResult LoadMatrixOp::verify() {
11021115
MemDescType mdescTy = getMemDesc().getType();
11031116

11041117
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
1105-
[&]() { return emitError(); });
1118+
getLayoutAttr(), [&]() { return emitError(); });
11061119
}
11071120

11081121
//===----------------------------------------------------------------------===//
@@ -1126,7 +1139,7 @@ LogicalResult StoreMatrixOp::verify() {
11261139
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
11271140
MemDescType mdescTy = getMemDesc().getType();
11281141
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
1129-
[&]() { return emitError(); });
1142+
getLayoutAttr(), [&]() { return emitError(); });
11301143
}
11311144

11321145
namespace mlir {

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -907,22 +907,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
907907
}
908908
};
909909

910-
static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
910+
static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
911911
PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
912912
Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
913-
SmallVector<Value> newOffsets;
914-
auto maybeDescOffsets =
915-
layout.computeDistributedOffsets(rewriter, loc, laneId, payloadShape);
916-
if (failed(maybeDescOffsets))
913+
SmallVector<Value> newCoods;
914+
auto maybeCoords =
915+
layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
916+
if (failed(maybeCoords))
917917
return {};
918-
assert(maybeDescOffsets.value().size() == 1 &&
918+
assert(maybeCoords.value().size() == 1 &&
919919
"Expected one set of distributed offsets");
920920
SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
921-
rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
921+
rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
922922
getAsOpFoldResult(origOffsets));
923-
newOffsets = llvm::to_vector(llvm::map_range(
923+
newCoods = llvm::to_vector(llvm::map_range(
924924
ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
925-
return newOffsets;
925+
return newCoods;
926926
}
927927

928928
/// Pattern for distributing xegpu::LoadMatrixOp.
@@ -969,7 +969,7 @@ struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
969969
getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
970970
if (failed(distPayloadByWarpOpOrFailure))
971971
return rewriter.notifyMatchFailure(
972-
matrixOp, "The matrix op payload has no layout.");
972+
matrixOp, "Failed to distribute matrix op payload based on layout.");
973973

974974
SmallVector<Value> operands = {matrixOp.getMemDesc()};
975975
const unsigned offsetsStartIdx = operands.size();
@@ -992,17 +992,17 @@ struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
992992
ValueRange currentOffsets =
993993
ValueRange(newOperands).drop_front(offsetsStartIdx);
994994

995-
SmallVector<Value> newOffsets = currentOffsets;
995+
SmallVector<Value> newCoords = currentOffsets;
996996
rewriter.setInsertionPointAfter(newWarpOp);
997997

998998
if (!matrixOp.getSubgroupBlockIoAttr()) {
999-
newOffsets = computeDistributedOffsetsForMatrixOp(
999+
newCoords = computeDistributedCoordinatesForMatrixOp(
10001000
rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
10011001
currentOffsets);
10021002
}
10031003
xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
10041004
rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
1005-
newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
1005+
newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
10061006
matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
10071007
// Resolve the output type and replace all uses.
10081008
rewriter.replaceAllUsesWith(
@@ -1045,7 +1045,7 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
10451045
getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
10461046
if (failed(distPayloadByWarpOpOrFailure))
10471047
return rewriter.notifyMatchFailure(
1048-
matrixOp, "The matrix op payload has no layout.");
1048+
matrixOp, "Failed to distribute matrix op payload based on layout.");
10491049

10501050
SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
10511051
const unsigned offsetsStartIdx = operands.size();
@@ -1069,18 +1069,18 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
10691069
ValueRange currentOffsets =
10701070
ValueRange(newOperands).drop_front(offsetsStartIdx);
10711071

1072-
SmallVector<Value> newOffsets = currentOffsets;
1072+
SmallVector<Value> newCoords = currentOffsets;
10731073
rewriter.setInsertionPointAfter(newWarpOp);
10741074

10751075
if (!matrixOp.getSubgroupBlockIoAttr()) {
1076-
newOffsets = computeDistributedOffsetsForMatrixOp(
1076+
newCoords = computeDistributedCoordinatesForMatrixOp(
10771077
rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
10781078
currentOffsets);
10791079
}
10801080

10811081
xegpu::StoreMatrixOp::create(
10821082
rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
1083-
ValueRange(newOffsets), newConstOffsetsAttr,
1083+
ValueRange(newCoords), newConstOffsetsAttr,
10841084
matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
10851085
rewriter.eraseOp(matrixOp);
10861086
return success();

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
115115
// descriptors to be accessed, based on the layout information.
116116
ArrayRef<int64_t> wgShape = op.getDataShape();
117117
auto maybeDescOffsets =
118-
layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
118+
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
119119
if (failed(maybeDescOffsets))
120120
return failure();
121121

@@ -832,7 +832,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
832832
Value sgId =
833833
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
834834
auto sgOffsets =
835-
layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
835+
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
836836
if (failed(sgOffsets))
837837
return failure();
838838

@@ -1054,7 +1054,7 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
10541054
Value sgId =
10551055
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
10561056
auto sgOffsets =
1057-
layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
1057+
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
10581058
if (failed(sgOffsets))
10591059
return failure();
10601060

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,3 +890,19 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
890890
xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
891891
return
892892
}
893+
894+
// -----
895+
func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, %arg1: vector<2x16xf32>) {
896+
// expected-error@+1 {{With subgroup_block_io, lane data must be linear}}
897+
xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
898+
vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>
899+
return
900+
}
901+
902+
// -----
903+
func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, %arg1: vector<16x2xf32>) {
904+
// expected-error@+1 {{With subgroup_block_io, lane data must be coalesced}}
905+
xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} :
906+
vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>
907+
return
908+
}

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,9 @@ gpu.module @xevm_module{
321321
gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
322322
%c0 = arith.constant 0 : index
323323
%c1 = arith.constant 1 : index
324-
%1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
324+
%1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
325325
!xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
326-
xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
326+
xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
327327
vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
328328
gpu.return
329329
}

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
201201
Value sgId =
202202
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
203203
auto maybeOffsets =
204-
sliceAttr.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
204+
sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape);
205205
if (failed(maybeOffsets))
206206
return failure();
207207

0 commit comments

Comments
 (0)