Skip to content

Commit e1a920b

Browse files
committed
save work
1 parent 14b4ff0 commit e1a920b

File tree

2 files changed

+47
-48
lines changed

2 files changed

+47
-48
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,15 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
189189
return scatter_attr.getChunkSize().getInt();
190190
return 1;
191191
}
192+
193+
/// Helper to drop all layout information from the TensorDesc type.
194+
TensorDescType dropLayouts() {
195+
if (getLayoutAttr() == xegpu::LayoutAttr())
196+
return *this;
197+
198+
return get(getContext(), getShape(), getElementType(), getEncoding(),
199+
xegpu::LayoutAttr());
200+
}
192201
}];
193202

194203
let hasCustomAssemblyFormat = true;

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 38 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -884,18 +884,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
884884
return VectorType::get(distributedShape, originalType.getElementType());
885885
}
886886

887-
// Drop the layout attribute from the tensor descriptor type if layout is
888-
// present.
889-
static xegpu::TensorDescType dropLayouts(xegpu::TensorDescType tensorDesc) {
890-
if (tensorDesc.getLayoutAttr() == xegpu::LayoutAttr())
891-
return tensorDesc;
892-
893-
return xegpu::TensorDescType::get(
894-
tensorDesc.getContext(), tensorDesc.getShape(),
895-
tensorDesc.getElementType(), tensorDesc.getEncoding(),
896-
xegpu::LayoutAttr());
897-
}
898-
899887
/// Helper function to resolve types if the distributed type out of
900888
/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
901889
/// Example 1:
@@ -1042,12 +1030,12 @@ struct MoveFuncBodyToWarpExecuteOnLane0
10421030
/// Example:
10431031
///
10441032
/// ```
1045-
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1033+
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
10461034
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1047-
/// (!xegpu.tensor_desc<4x8xf32, #lo0>) {
1035+
/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
10481036
/// ...
10491037
/// %td = xegpu.create_nd_tdesc %arg0[0, 0]
1050-
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #lo0>
1038+
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
10511039
/// vector.yield %td
10521040
/// }
10531041
/// ```
@@ -1056,7 +1044,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0
10561044
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
10571045
/// ...
10581046
/// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
1059-
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #lo0>
1047+
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
10601048
/// vector.yield %arg0, %dead
10611049
/// }
10621050
/// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
@@ -1099,8 +1087,8 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
10991087
}
11001088
rewriter.setInsertionPointAfter(newWarpOp);
11011089
xegpu::TensorDescType distributedTensorDescTy =
1102-
dropLayouts(descOp.getType()); // Distributed tensor descriptor type
1103-
// does not contain layout info.
1090+
descOp.getType().dropLayouts(); // Distributed tensor descriptor type
1091+
// does not contain layout info.
11041092
auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
11051093
newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
11061094
descOp->getAttrs());
@@ -1120,23 +1108,23 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
11201108
/// Example:
11211109
///
11221110
/// ```
1123-
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1111+
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
11241112
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
11251113
/// ...
11261114
/// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
1127-
/// !xegpu.tensor_desc<4x8xf32, #lo0>
1115+
/// !xegpu.tensor_desc<4x8xf32, #layout0>
11281116
/// }
11291117
/// ```
11301118
/// To
11311119
/// ```
11321120
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
1133-
/// !xegpu.tensor_desc<4x8xf32, #lo0>) {
1121+
/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
11341122
/// gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
1135-
/// #lo0>
1123+
/// #layout0>
11361124
/// }
11371125
/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
11381126
/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1139-
/// #lo0>
1127+
/// #layout0>
11401128
/// -> !xegpu.tensor_desc<4x8xf32>
11411129
/// xegpu.store_nd %0, %1: vector<4xf32>,
11421130
/// !xegpu.tensor_desc<4x8xf32>
@@ -1195,7 +1183,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
11951183
// For the tensor descriptor operand, the layout attribute is dropped after
11961184
// distribution. Types needs to be resolved in this case also.
11971185
xegpu::TensorDescType distributedTensorDescTy =
1198-
dropLayouts(storeOp.getTensorDescType());
1186+
storeOp.getTensorDescType().dropLayouts();
11991187
newStoreOperands.push_back(
12001188
resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
12011189
distributedTensorDescTy, rewriter));
@@ -1220,25 +1208,26 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
12201208
/// Example:
12211209
///
12221210
/// ```
1223-
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1211+
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
12241212
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
12251213
/// (vector<4x1xf32>) {
12261214
/// ...
1227-
/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #lo0> ->
1215+
/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
1216+
/// ->
12281217
/// vector<4x8xf32>
12291218
/// gpu.yield %ld
12301219
/// }
12311220
/// ```
12321221
/// To
12331222
/// ```
12341223
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
1235-
/// !xegpu.tensor_desc<4x8xf32, #lo0>) {
1224+
/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
12361225
/// ...
1237-
/// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #lo0> ->
1226+
/// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
12381227
/// vector<4x8xf32> gpu.yield %dead, %arg0
12391228
/// }
12401229
/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1241-
/// #lo0> -> !xegpu.tensor_desc<4x8xf32>
1230+
/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
12421231
/// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
12431232
/// %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
12441233
///
@@ -1279,9 +1268,9 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
12791268
return rewriter.notifyMatchFailure(
12801269
loadOp, "Failed to get distributed vector type for the load op");
12811270
xegpu::TensorDescType distributedTensorDescTy =
1282-
dropLayouts(loadOp.getTensorDescType()); // Distributed tensor
1283-
// descriptor type does not
1284-
// contain layout info.
1271+
loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
1272+
// descriptor type does not
1273+
// contain layout info.
12851274
auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
12861275
newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
12871276
resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
@@ -1439,28 +1428,29 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
14391428
/// the distributed types does not match expected xegpu SIMT types.
14401429
/// Example:
14411430
/// ```
1442-
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1431+
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
14431432
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1444-
/// (!xegpu.tensor_desc<4x8xf32, #lo0>) {
1433+
/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
14451434
/// ...
14461435
/// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1447-
/// !xegpu.tensor_desc<4x8xf32, #lo0>
1436+
/// !xegpu.tensor_desc<4x8xf32, #layout0>
14481437
/// gpu.yield %update
14491438
/// }
14501439
/// ...
14511440
/// ```
14521441
/// To
14531442
/// ```
1454-
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
1455-
/// !xegpu.tensor_desc<4x8xf32, #lo0>) {
1443+
/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (
1444+
/// !xegpu.tensor_desc<4x8xf32, #layout0>,
1445+
/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
14561446
/// ...
14571447
/// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1458-
/// !xegpu.tensor_desc<4x8xf32, #lo0> gpu.yield %dead, %arg0
1459-
/// gup.yield %dead, %arg0, %c32, %c16
1448+
/// !xegpu.tensor_desc<4x8xf32, #layout0>
1449+
/// gpu.yield %dead, %arg0, %c32, %c16
14601450
/// }
14611451
/// %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1462-
/// #lo0> -> !xegpu.tensor_desc<4x8xf32>
1463-
/// %1 = xegpu.update_nd_offset %0, [%c32, %c16]:
1452+
/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
1453+
/// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
14641454
/// !xegpu.tensor_desc<4x8xf32>
14651455
/// ...
14661456
/// ```
@@ -1477,7 +1467,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
14771467
unsigned operandIdx = operand->getOperandNumber();
14781468
// new update op does not have layout attribute.
14791469
xegpu::TensorDescType newTensorDescTy =
1480-
dropLayouts(updateOp.getTensorDescType());
1470+
updateOp.getTensorDescType().dropLayouts();
14811471

14821472
SmallVector<Value, 3> newYieldValues;
14831473
SmallVector<Type, 3> newYieldTypes;
@@ -1523,20 +1513,20 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
15231513
/// Example:
15241514
///
15251515
/// ```
1526-
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1516+
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
15271517
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
15281518
/// ...
1529-
/// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #lo0>
1519+
/// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
15301520
/// }
15311521
/// ```
15321522
/// To
15331523
/// ```
15341524
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (
1535-
// !xegpu.tensor_desc<4x8xf32, #lo0>) {
1536-
/// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #lo0>
1525+
/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
1526+
/// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
15371527
/// }
15381528
/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
1539-
/// #lo0> -> !xegpu.tensor_desc<4x8xf32>
1529+
/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
15401530
/// xegpu.prefetch_nd %0 : !xegpu.tensor_desc<4x8xf32>
15411531
///
15421532
/// ```
@@ -1563,7 +1553,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
15631553
// Create a new prefetch op outside the warp op with updated tensor
15641554
// descriptor type. Source tensor descriptor require type resolution.
15651555
xegpu::TensorDescType newTensorDescTy =
1566-
dropLayouts(prefetchOp.getTensorDescType());
1556+
prefetchOp.getTensorDescType().dropLayouts();
15671557
rewriter.setInsertionPointAfter(newWarpOp);
15681558
SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
15691559
newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};

0 commit comments

Comments
 (0)