Skip to content

Commit 326a1a4

Browse files
authored
[MLIR][XeGPU] Add anchor_layout and update propagation to honor user-specified layouts (#169267)
Introduce anchor layout for XeGPU anchor ops: load_nd, store_nd, prefetch_nd, dpas, load, store, prefetch, load_matrix, store_matrix, and atomic_rmw. Anchor layout is permanent, and is guaranteed to be honored by XeGPU distribution and lowerinngs once specified. 1. Add anchor_layout for XeGPU anchor OPs: load_nd, store_nd, prefetch_nd, dpas, load, store, prefetch, load_matrix, store_matrix, and atomic_rmw. 2. rename layout attributes to anchor_layout for these ops: load, store, load_matrix, store_matrix 3. update layout propagation pass: Only when user doesn't specify anchor layout, the pass computes a default layout and set to anchor op's permant layout and use that for propagation. if user specified anchor layout, the pass takes user-specified anchor layout. permant layout and use that for propagation. if user specified anchor layout, the pass takes user-specified anchor layout.
1 parent a9cc7fe commit 326a1a4

File tree

7 files changed

+672
-362
lines changed

7 files changed

+672
-362
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 329 additions & 119 deletions
Large diffs are not rendered by default.

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
465465
xegpu::CachePolicyAttr l3_hint) {
466466

467467
return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
468-
l1_hint, l2_hint, l3_hint);
468+
l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
469469
}
470470

471471
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
@@ -480,7 +480,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
480480
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
481481

482482
build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
483-
l2_hint, l3_hint);
483+
l2_hint, l3_hint, /*anchor_layout=*/nullptr);
484484
}
485485

486486
LogicalResult PrefetchNdOp::verify() {
@@ -519,7 +519,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
519519

520520
return build(builder, state, retType, tensorDesc, ValueRange(),
521521
DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
522-
l3_hint);
522+
l3_hint, /*anchor_layout=*/nullptr);
523523
}
524524

525525
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
@@ -535,7 +535,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
535535
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
536536

537537
build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
538-
packed, transpose, l1_hint, l2_hint, l3_hint);
538+
packed, transpose, l1_hint, l2_hint, l3_hint,
539+
/*anchor_layout=*/nullptr);
539540
}
540541

541542
LogicalResult LoadNdOp::verify() {
@@ -638,7 +639,8 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
638639
xegpu::CachePolicyAttr l3_hint) {
639640

640641
return build(builder, state, value, tensorDesc, ValueRange(),
641-
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
642+
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
643+
/*anchor_layout=*/nullptr);
642644
}
643645

644646
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
@@ -653,7 +655,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
653655
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
654656

655657
build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
656-
l1_hint, l2_hint, l3_hint);
658+
l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
657659
}
658660

659661
LogicalResult StoreNdOp::verify() {
@@ -826,7 +828,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
826828
xegpu::CachePolicyAttr l2_hint,
827829
xegpu::CachePolicyAttr l3_hint) {
828830
build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
829-
IntegerAttr{});
831+
IntegerAttr{}, /*anchor_layout=*/nullptr);
830832
}
831833

832834
//===----------------------------------------------------------------------===//
@@ -876,7 +878,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
876878
xegpu::CachePolicyAttr l2_hint,
877879
xegpu::CachePolicyAttr l3_hint) {
878880
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
879-
l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
881+
l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
880882
}
881883

882884
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -892,7 +894,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
892894
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
893895

894896
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
895-
l2_hint, l3_hint, /*layout=*/nullptr);
897+
l2_hint, l3_hint, /*anchor_layout=*/nullptr);
896898
}
897899

898900
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -960,7 +962,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
960962
xegpu::CachePolicyAttr l2_hint,
961963
xegpu::CachePolicyAttr l3_hint) {
962964
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
963-
l2_hint, l3_hint, /*layout=*/nullptr);
965+
l2_hint, l3_hint, /*anchor_layout=*/nullptr);
964966
}
965967

966968
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -978,7 +980,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
978980

979981
// Call the correct builder overload that does not expect result types.
980982
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
981-
l3_hint, /*layout=*/nullptr);
983+
l3_hint, /*anchor_layout=*/nullptr);
982984
}
983985

984986
void StoreScatterOp::build(

0 commit comments

Comments
 (0)