Skip to content

Commit 966525b

Browse files
committed
remove vector direction and lenght attirbutes
1 parent 0344761 commit 966525b

File tree

7 files changed

+22
-72
lines changed

7 files changed

+22
-72
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -724,22 +724,4 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
724724

725725
}
726726

727-
def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
728-
def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
729-
def MatrixAccessDirection :
730-
I32EnumAttr<"MatrixAccessDirection",
731-
"Matrix elements/vectors can have row or column direction", [
732-
RowOriented, ColOriented
733-
]> {
734-
let genSpecializedAttr = 0;
735-
let cppNamespace = "::mlir::xegpu";
736-
}
737-
def MatrixAccessDirectionAttr :
738-
EnumAttr<XeGPU_Dialect,
739-
MatrixAccessDirection,
740-
"matrix_access_direction">{
741-
let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}];
742-
let assemblyFormat = "`<` $value `>`";
743-
}
744-
745727
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,8 +1302,6 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13021302
let arguments = (ins XeGPU_MemDesc:$mem_desc,
13031303
Variadic<Index>: $offsets,
13041304
DenseI64ArrayAttr: $const_offsets,
1305-
OptionalAttr<I32Attr>:$vec_length,
1306-
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
13071305
OptionalAttr<UnitAttr>:$subgroup_block_io,
13081306
OptionalAttr<DistributeLayoutAttr>:$layout
13091307
);
@@ -1355,8 +1353,6 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13551353
XeGPU_MemDesc:$mem_desc,
13561354
Variadic<Index>: $offsets,
13571355
DenseI64ArrayAttr: $const_offsets,
1358-
OptionalAttr<I32Attr>:$vec_length,
1359-
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
13601356
OptionalAttr<UnitAttr>:$subgroup_block_io,
13611357
OptionalAttr<DistributeLayoutAttr>:$layout
13621358
);

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ class CreateNdDescToXeVMPattern
184184
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
185185
// Descriptor shape is expected to be 2D.
186186
int64_t rank = mixedSizes.size();
187-
if (rank != 2) {
187+
if (rank != 2)
188188
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
189-
}
189+
190190
auto sourceTy = source.getType();
191191
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
192192
// If source is a memref, we need to extract the aligned pointer as index.
@@ -658,20 +658,6 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
658658
return rewriter.notifyMatchFailure(
659659
op, "The lowering is specific to pvc or bmg.");
660660
}
661-
xegpu::MatrixAccessDirectionAttr vecDirection =
662-
op.getVecDirectionAttr();
663-
if (vecDirection &&
664-
vecDirection.getValue() == xegpu::MatrixAccessDirection::COL &&
665-
!mdescTy.isColMajor())
666-
return rewriter.notifyMatchFailure(
667-
op, "mem_desc should be column major when "
668-
"vec_direction is COLUMN for 1D result.");
669-
if (vecDirection &&
670-
vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW &&
671-
mdescTy.isColMajor())
672-
return rewriter.notifyMatchFailure(
673-
op, "mem_desc should be row major when "
674-
"vec_direction is ROW for 1D result.");
675661

676662
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
677663
Value loadOp =

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -173,17 +173,18 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
173173
return success();
174174
}
175175

176-
LogicalResult IsValidStoreMatrixParams(
177-
VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io,
178-
MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength,
179-
function_ref<InFlightDiagnostic()> emitError) {
180-
181-
if (!dataTy)
182-
if (subgroup_block_io || vecDirection || vecLength)
183-
return emitError() << "vec_length, vec_direction and subgroup_block_io "
176+
LogicalResult
177+
IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
178+
UnitAttr subgroup_block_io,
179+
function_ref<InFlightDiagnostic()> emitError) {
180+
181+
if (!dataTy) {
182+
if (subgroup_block_io)
183+
return emitError() << "subgroup_block_io "
184184
"are only allowed when result is a 1D VectorType.";
185185
else
186186
return success();
187+
}
187188

188189
if (mdescTy.getRank() != 2)
189190
return emitError() << "mem_desc must be 2D.";
@@ -192,8 +193,8 @@ LogicalResult IsValidStoreMatrixParams(
192193
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
193194

194195
if (dataShape.size() == 2) {
195-
if (subgroup_block_io || vecDirection || vecLength)
196-
return emitError() << "vec_length, vec_direction and subgroup_block_io "
196+
if (subgroup_block_io)
197+
return emitError() << "subgroup_block_io "
197198
"are only allowed when result is a 1D VectorType.";
198199
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
199200
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -1097,20 +1098,16 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
10971098
// Call the generated builder with all parameters (including optional ones as
10981099
// nullptr/empty)
10991100
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1100-
/*vec_length=*/nullptr, /*vec_direction=*/nullptr,
11011101
/*subgroup_block_io=*/nullptr, layout);
11021102
}
11031103

11041104
LogicalResult LoadMatrixOp::verify() {
11051105

11061106
auto resTy = dyn_cast<VectorType>(getRes().getType());
11071107
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1108-
MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
1109-
IntegerAttr vecLength = getVecLengthAttr();
11101108
MemDescType mdescTy = getMemDesc().getType();
11111109

11121110
return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io,
1113-
vecDirection, vecLength,
11141111
[&]() { return emitError(); });
11151112
}
11161113

@@ -1126,19 +1123,15 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
11261123
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
11271124
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
11281125
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1129-
/*vec_length=*/nullptr, /*vec_direction=*/nullptr,
11301126
/*subgroup_block_io=*/nullptr, layout);
11311127
}
11321128

11331129
LogicalResult StoreMatrixOp::verify() {
11341130

11351131
auto dataTy = dyn_cast<VectorType>(getData().getType());
11361132
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1137-
MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
1138-
IntegerAttr vecLength = getVecLengthAttr();
11391133
MemDescType mdescTy = getMemDesc().getType();
11401134
return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io,
1141-
vecDirection, vecLength,
11421135
[&]() { return emitError(); });
11431136
}
11441137

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,10 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
139139

140140
%tid_x = gpu.thread_id x
141141
%c16 = arith.constant 16 : index
142-
%1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
142+
%1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
143143

144144
//CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
145-
xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
145+
xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
146146

147147
gpu.return %1: vector<8xf16>
148148
}

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -870,16 +870,9 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
870870
return
871871
}
872872

873-
// -----
874-
func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) {
875-
// expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
876-
%data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
877-
return
878-
}
879-
880873
// -----
881874
func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
882-
// expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
875+
// expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
883876
%data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
884877
return
885878
}
@@ -908,14 +901,14 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
908901

909902
// -----
910903
func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
911-
// expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
904+
// expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
912905
xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
913906
return
914907
}
915908

916909
// -----
917910
func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
918-
// expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
911+
// expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
919912
xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
920913
return
921914
}

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,8 @@ gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #
855855

856856
// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
857857
gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
858-
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
859-
%data = xegpu.load_matrix %arg0[8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
858+
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
859+
%data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
860860
gpu.return
861861
}
862862

@@ -890,8 +890,8 @@ gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16,
890890

891891
// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
892892
gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
893-
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
894-
xegpu.store_matrix %arg1, %arg0[8, 8] <{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
893+
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
894+
xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
895895
gpu.return
896896
}
897897

0 commit comments

Comments
 (0)