Skip to content

Commit 07f9f9f

Browse files
committed
fix
1 parent 491625d commit 07f9f9f

File tree

2 files changed

+16
-43
lines changed

2 files changed

+16
-43
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -341,24 +341,6 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
341341
return rewriter.notifyMatchFailure(
342342
descOp, "the tensor descriptor lacks sg_map attribute");
343343

344-
auto layout = sgMap.getWiLayout();
345-
346-
// Calculate the offset within tensor descriptor for the current lane_id. The
347-
// access to proper element for a work item is done through a lane-specific
348-
// subview (tdesc offsets are used as base, lane shift is added on top).
349-
auto laneid = warpOp.getLaneid();
350-
auto xDim =
351-
rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
352-
auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
353-
auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
354-
355-
auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
356-
descOffsets[0]);
357-
auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
358-
descOffsets[1]);
359-
auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
360-
auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
361-
362344
auto distributedDescTypeOrFailure = getDistributedTensorDescType(
363345
descOp.getType(), sgMap, descOp.getType().getMemorySpace());
364346
if (failed(distributedDescTypeOrFailure))
@@ -378,15 +360,10 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
378360
newRetIndices);
379361

380362
rewriter.setInsertionPointAfter(newWarpOp);
381-
auto subview = rewriter.create<memref::SubViewOp>(
382-
newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
383-
overwriteSizes, overwriteStrides);
384-
subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
385-
386-
auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
387363
auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
388-
newWarpOp.getLoc(), newTDescType, subview,
389-
getAsOpFoldResult({zero, zero}));
364+
newWarpOp.getLoc(), newTDescType,
365+
dyn_cast<TypedValue<MemRefType>>(newWarpOp.getResult(newRetIndices[0])),
366+
descOffsets);
390367

391368
Value distributedVal = newWarpOp.getResult(operandIdx);
392369
rewriter.replaceAllUsesWith(distributedVal, newDescOp);

mlir/test/Dialect/XeGPU/xegpu-distribute.mlir

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
// CHECK: %[[res:.*]]:2 = gpu.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}}, %{{.*}} : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
99
// CHECK-SAME: -> (!xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x2xf16>)
1010
// CHECK: ^bb0(%[[src:.*]]: vector<24x32xf16>, %[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
11-
// CHECK: gpu.yield%[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
11+
// CHECK: gpu.yield %[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
1212
// CHECK: xegpu.store_nd %[[res]]#1, %[[res]]#0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> :
13-
// CHECK-SAME: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
13+
// CHECK: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
1414

1515
func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> () {
1616
%laneid = gpu.lane_id
@@ -23,7 +23,6 @@ func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tens
2323
}
2424

2525
// -----
26-
2726
#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
2827
#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
2928

@@ -33,7 +32,7 @@ func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tens
3332
// CHECK-SAME: -> (vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
3433
// CHECK: ^bb0(%[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
3534
// CHECK: %[[dead:.*]] = xegpu.load_nd
36-
// CHECK: gpu.yield%[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
35+
// CHECK: gpu.yield %[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
3736
// CHECK: %[[load:.*]] = xegpu.load_nd %[[res]]#1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
3837
// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<24x2xf16>
3938
// CHECK: return %[[load]]
@@ -56,26 +55,23 @@ func.func @test_load_nd_distribution(%dst: !xegpu.tensor_desc<24x32xf16, #blk_td
5655
#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
5756

5857
// CHECK-LABEL: test_create_nd_desc_distribution
59-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
58+
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
6059
// CHECK: %[[laneid:.*]] = gpu.lane_id
6160
// CHECK: %[[res:.*]]:2 = gpu.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : memref<24x32xf16>)
62-
// CHECK-SAME: -> (!xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
61+
// CHECK-SAME: -> (!xegpu.tensor_desc<12x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
6362
// CHECK: ^bb0(%[[dst:.*]]: memref<24x32xf16>)
6463
// CHECK: %[[dead:.*]] = xegpu.create_nd_tdesc
65-
// CHECK: gpu.yield%[[dead]], %[[dst]] :
66-
// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>
67-
// CHECK: %[[view:.*]] = memref.subview %[[res]]#1[%[[C0]], %[[laneid]]] [24, 2] [1, 1] : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>
68-
// CHECK: %[[desc:.*]] = xegpu.create_nd_tdesc %[[view]][0, 0] : memref<24x2xf16, strided<[32, 1], offset: ?>>
69-
// CHECK-SAME: -> !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
70-
// CHECK: return %[[desc]]
64+
// CHECK: gpu.yield %[[dead]], %[[dst]] :
65+
7166

72-
func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
67+
func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>) {
7368
%laneid = gpu.lane_id
69+
%c12 = arith.constant 12 : index
7470
%r = gpu.warp_execute_on_lane_0(%laneid)[16]
75-
args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
71+
args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>) {
7672
^bb0(%arg0: memref<24x32xf16>):
77-
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
78-
gpu.yield%0 : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
73+
%0 = xegpu.create_nd_tdesc %arg0[%c12, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>
74+
gpu.yield%0 : !xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>
7975
}
80-
return %r : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
76+
return %r : !xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>
8177
}

0 commit comments

Comments (0)