Commit b4f5a4d

Relax subgroup_block_io dimensionality restriction
1 parent f80ee32 commit b4f5a4d
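The verifier no longer forces subgroup_block_io loads/stores of matrices to produce 1D vectors: any VectorType data is now accepted (scalar data remains disallowed), and the 1D requirement is enforced only where the XeVM lowering actually needs it. A minimal sketch of the newly legal form, adapted from the load_store_matrix_3 test added below (shapes and the mem_layout are illustrative):

gpu.module @example {
  gpu.func @subgroup_block_io_2d(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
    %c0 = arith.constant 0 : index
    // Previously rejected ("subgroup_block_io are only allowed when result is
    // a 1D VectorType."); now verifies, since the result is a VectorType.
    %v = xegpu.load_matrix %arg0[%c0, %c0] <{subgroup_block_io}> :
        !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
    xegpu.store_matrix %v, %arg0[%c0, %c0] <{subgroup_block_io}> :
        vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
    gpu.return
  }
}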

File tree

4 files changed: +45 -30 lines

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/test/Dialect/XeGPU/invalid.mlir
mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 2 additions & 0 deletions

@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
     if (!valOrResVecTy)
       valOrResVecTy = VectorType::get(1, data.getType());
+    if (valOrResVecTy.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
 
     int64_t elemBitWidth =
         valOrResVecTy.getElementType().getIntOrFloatBitWidth();
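The XeVM lowering itself still handles only 1D data, so the dimensionality check becomes a match failure in this pattern rather than a verifier error. A hypothetical case illustrating the guard (a sketch; the shapes are made up, and the constant-offset syntax follows the removed invalid.mlir tests below):

gpu.module @example {
  gpu.func @not_lowered_here(%m: !xegpu.mem_desc<16x64xf16>) {
    // Passes the relaxed verifier, but LoadStoreMatrixToXeVMPattern now
    // reports "Expected 1D data vector." instead of rewriting it.
    %v = xegpu.load_matrix %m[8, 8] <{subgroup_block_io}> : !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
    gpu.return
  }
}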

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 1 addition & 4 deletions

@@ -181,7 +181,7 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
   if (!dataTy) {
     if (subgroup_block_io)
       return emitError() << "subgroup_block_io "
-                            "are only allowed when result is a 1D VectorType.";
+                            "are only allowed when result is a VectorType.";
     else
       return success();
   }
@@ -193,9 +193,6 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
 
   if (dataShape.size() == 2) {
-    if (subgroup_block_io)
-      return emitError() << "subgroup_block_io "
-                            "are only allowed when result is a 1D VectorType.";
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
       return emitError() << "data shape must not exceed mem_desc shape.";

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 1 addition & 24 deletions

@@ -451,7 +451,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error@+1 {{Mask should match value except the chunk size dim}}
-  xegpu.store %val, %src[%offsets], %mask
+  xegpu.store %val, %src[%offsets], %mask
   : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
 }
@@ -870,14 +870,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
   return
 }
 
-// -----
-func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
-  return
-}
-
-
 // -----
 func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
   // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}}
@@ -898,18 +890,3 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
   xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
   return
 }
-
-// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
-  // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
-  return
-}
-
-// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
-  // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
-  return
-}
-

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 41 additions & 2 deletions

@@ -271,8 +271,8 @@ gpu.module @xevm_module{
 // CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
 // CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
 // CHECK: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%0]
-// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%0]
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
 // CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
 // CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
 // CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
@@ -285,3 +285,42 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) {
+// CHECK: %[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index
+// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
+// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 16], stride = [1, 32]>>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<2x1xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 16], stride = [1, 32]>>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+        !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
+    xegpu.store_matrix %1, %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+        vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+    gpu.return
+  }
+}
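For reference, the arithmetic checked in load_store_matrix_2: with lane_layout = [4, 4] and lane_data = [2, 1], the distribution unit is 8x4, so the vector<8x4xf32> is split across the 16 lanes into vector<2x1xf32> pieces; a lane's row offset is index.remu(lane_y * 2, 8) and its column offset is index.remu(lane_x, 4), exactly the index.mul/index.remu chain in the CHECK lines above.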
