Skip to content

Commit 10448e1

Browse files
committed
Improve verification
1 parent 246761e commit 10448e1

File tree

4 files changed

+42
-22
lines changed

4 files changed

+42
-22
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
651651
/// multiple blocks according to round-robin distribution rules.
652652

653653
FailureOr<SmallVector<SmallVector<Value>>>
654-
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape);
654+
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
655655

656656
/// Check if this is a slice of some other layout.
657657
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,26 +191,38 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
191191

192192
ArrayRef<int64_t> dataShape = dataTy.getShape();
193193
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
194+
195+
SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
196+
ArrayAttr strideAttr = mdescTy.getStrideAttr();
197+
SmallVector<int64_t> strides;
198+
for (Attribute attr : strideAttr.getValue()) {
199+
strides.push_back(cast<IntegerAttr>(attr).getInt());
200+
}
194201
if (subgroup_block_io && layout) {
195202
auto laneData = layout.getEffectiveLaneDataAsInt();
203+
auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
196204
if (!laneData.empty()) {
197-
bool isLaneDataLinear =
205+
bool isLaneDataContiguous =
198206
std::all_of(laneData.begin(), std::prev(laneData.end()),
199207
[](int x) { return x == 1; });
200-
if (!isLaneDataLinear)
201-
return emitError()
202-
<< "With subgroup_block_io, lane data must be linear.";
203-
if (isLaneDataLinear && laneData.back() != 1)
204-
return emitError()
205-
<< "With subgroup_block_io, lane data must be coalesced.";
208+
if (!isLaneDataContiguous)
209+
return emitError() << "With subgroup_block_io, accessed data must be "
210+
"contiguous and coalesced.";
211+
for (int i = 0; i < laneData.size(); ++i) {
212+
if (laneLayout[i] != blockShape[i])
213+
return emitError() << "With subgroup_block_io, the block shape must "
214+
"match the lane layout.";
215+
if (laneLayout[i] != 1 && strides[i] != 1)
216+
return emitError() << "With subgroup_block_io, the distributed "
217+
"dimensions must be contiguous.";
218+
}
206219
}
207220
}
208221
if (dataShape.size() == 2) {
209222
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
210223
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
211224
return emitError() << "data shape must not exceed mem_desc shape.";
212225
} else {
213-
SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
214226
// if the subgroup_block_io attribute is set, mdescTy must have block
215227
// attribute
216228
if (subgroup_block_io && !blockShape.size())

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -892,17 +892,25 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
892892
}
893893

894894
// -----
895-
func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, %arg1: vector<2x16xf32>) {
896-
// expected-error@+1 {{With subgroup_block_io, lane data must be linear}}
895+
func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>, %arg1: vector<2x16xf32>) {
896+
// expected-error@+1 {{With subgroup_block_io, accessed data must be contiguous and coalesced}}
897897
xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
898-
vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>
898+
vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>
899899
return
900900
}
901901

902902
// -----
903-
func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, %arg1: vector<16x2xf32>) {
904-
// expected-error@+1 {{With subgroup_block_io, lane data must be coalesced}}
903+
func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>, %arg1: vector<16x2xf32>) {
904+
// expected-error@+1 {{With subgroup_block_io, the distributed dimensions must be contiguous}}
905905
xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} :
906-
vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>
906+
vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>
907+
return
908+
}
909+
910+
// -----
911+
func.func @simt_store_matrix_vector_block_mismatch(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>, %arg1: vector<16x2xf32>) {
912+
// expected-error@+1 {{With subgroup_block_io, the block shape must match the lane layout}}
913+
xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
914+
vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>
907915
return
908916
}

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,17 +314,17 @@ gpu.module @xevm_module{
314314
// -----
315315
// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) {
316316
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
317-
// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 16], stride = [1, 32]>>, index, index -> vector<2x1xf32>
317+
// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32>
318318
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
319-
// CHECK-SAME: vector<2x1xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 16], stride = [1, 32]>>, index, index
319+
// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index
320320
gpu.module @xevm_module{
321-
gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
321+
gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) {
322322
%c0 = arith.constant 0 : index
323323
%c1 = arith.constant 1 : index
324-
%1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
325-
!xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
326-
xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
327-
vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
324+
%1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
325+
!xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
326+
xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
327+
vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index
328328
gpu.return
329329
}
330330
}

0 commit comments

Comments
 (0)