Skip to content

Commit 101b67a

Browse files
committed
Revert "Remove base_pitch and use number of elements for base_width and base_height."
This reverts commit 7e1514a.
1 parent 7e1514a commit 101b67a

File tree

4 files changed

+55
-39
lines changed

4 files changed

+55
-39
lines changed

mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ def XeVM_BlockLoad2dOp
191191
: XeVM_Op<"blockload2d">,
192192
Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>,
193193
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
194-
I32:$base_height, I32:$x, I32:$y, I32Attr:$elem_size_in_bits,
195-
I32Attr:$tile_width, I32Attr:$tile_height, I32Attr:$v_blocks,
196-
I1Attr:$transpose, I1Attr:$pack_register,
194+
I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
195+
I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
196+
I32Attr:$v_blocks, I1Attr:$transpose, I1Attr:$pack_register,
197197
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
198198

199199
let summary = "2D block load";
@@ -202,7 +202,9 @@ def XeVM_BlockLoad2dOp
202202
The `xevm.blockload2d` operation loads a two dimensional matrix tile
203203
from a base matrix residing in global memory. The parameters are:
204204
$ptr - the base address of the base matrix containing the tile to load
205-
$base_width, $base_height, the shape of the base matrix in number of elements.
205+
$base_width, $base_height, $base_pitch - the shape of the base matrix.
206+
pitch is the physical stride between the first columns of the current row
207+
and the subsequent row. All units are in bytes.
206208
$x, $y, $tile_width, $tile_height - the starting offsets and shape of
207209
the tile to load in number of elements.
208210
$elem_size_in_bits - the size in bits of the matrix element type
@@ -225,9 +227,10 @@ def XeVM_BlockLoad2dOp
225227
```mlir
226228
%base_width_a = arith.constant 32 : i32
227229
%base_height_a = arith.constant 8 : i32
230+
%base_pitch_a = arith.constant 32 : i32
228231
%x = arith.constant 0 : i32
229232
%y = arith.constant 0 : i32
230-
%loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %x, %y
233+
%loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
231234
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32,
232235
v_blocks=1 : i32, transpose=false : i32, pack_register=false,
233236
cache_control=#xevm.load_cache_control<Default>}>
@@ -248,8 +251,8 @@ def XeVM_BlockLoad2dOp
248251
def XeVM_BlockStore2dOp
249252
: XeVM_Op<"blockstore2d">,
250253
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr, I32:$base_width,
251-
I32:$base_height, I32:$x, I32:$y, I32Attr:$elem_size_in_bits,
252-
I32Attr:$tile_width, I32Attr:$tile_height,
254+
I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
255+
I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
253256
FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$stored_val,
254257
OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
255258

@@ -259,9 +262,11 @@ def XeVM_BlockStore2dOp
259262
The `xevm.blockstore2d` operation stores a two dimensional tile into a
260263
larger matrix residing in global memory. The parameters are:
261264
$ptr - the base address of the target matrix where to store the tile
262-
$base_width, $base_height, the shape of the target matrix in number of elements.
265+
$base_width, $base_height, $base_pitch - the shape of the target matrix. pitch is the
266+
physical stride between the first columns of the current row and the subsequent row.
267+
All units are in bytes.
263268
$x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to store
264-
in number of elements.
269+
in number of elements.
265270
$elem_size_in_bits - the size in bits of the matrix element
266271
- 32 for f32, tf32
267272
- 16 for f16, int16, bf16
@@ -273,9 +278,10 @@ def XeVM_BlockStore2dOp
273278
```mlir
274279
%base_width_c = arith.constant 64 : i32
275280
%base_height_c = arith.constant 8 : i32
281+
%base_pitch_c = arith.constant 64 : i32
276282
%x = arith.constant 0 : i32
277283
%y = arith.constant 0 : i32
278-
xevm.blockstore2d %dst, %base_width_c, %base_height_c, %x, %y, %src
284+
xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %src
279285
<{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32,
280286
cache_control=#xevm.load_cache_control<Default>}>
281287
: (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
@@ -377,8 +383,9 @@ def XeVM_PrefetchOp
377383
def XeVM_BlockPrefetch2dOp
378384
: XeVM_Op<"blockprefetch2d">,
379385
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
380-
I32:$base_height, I32:$x, I32:$y, I32Attr:$elem_size_in_bits,
381-
I32Attr:$tile_width, I32Attr:$tile_height, I32Attr:$v_blocks,
386+
I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
387+
I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
388+
I32Attr:$v_blocks,
382389
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
383390

384391
let summary = "2D block prefetch";
@@ -387,7 +394,9 @@ def XeVM_BlockPrefetch2dOp
387394
The `xevm.blockprefetch2d` operation prefetches a two dimensional tile
388395
from a larger base matrix residing in global memory. The parameters are:
389396
$ptr - the base address of the base matrix containing the tile to prefetch
390-
$base_width, $base_height - the shape of the base matrix in number of elements.
397+
$base_width, $base_height, $base_pitch - the shape of the base matrix.
398+
pitch is the physical stride between the first columns of the current row
399+
and the subsequent row. All units are in bytes.
391400
$x, $y, $tile_width, $tile_height - the starting offsets and shape of tile
392401
to prefetch in number of elements.
393402
$elem_size_in_bits - the size in bits of the matrix element
@@ -399,7 +408,7 @@ def XeVM_BlockPrefetch2dOp
399408

400409
Example:
401410
```mlir
402-
xevm.blockprefetch2d %ptr, %base_width, %base_height, %x, %y
411+
xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y
403412
<{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32,
404413
v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
405414
: (!llvm.ptr<1>, i32, i32, i32, i32, i32)

mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ LogicalResult verifyMatrixInput(Op op) {
2828
static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
2929
BlockPrefetch2dOp>::value,
3030
"Unexpected template parameter");
31+
32+
std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
33+
std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
34+
if (pitch && width && *pitch < *width)
35+
return op->emitOpError(
36+
"4th operand (base pitch) should be >= 2nd operand (base width)");
37+
3138
uint32_t elemSize = op.getElemSizeInBits();
3239
if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
3340
return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,25 +1909,25 @@ llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8
19091909

19101910
// -----
19111911

1912-
llvm.func @invalid_xevm_matrix_1(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
1912+
llvm.func @invalid_xevm_matrix_1(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
19131913
// expected-error@+1 {{op expecting tile_width to be between 1 and 8}}
1914-
xevm.blockstore2d %c, %base_width_c, %base_height_c, %x, %y, %c_result_casted <{elem_size_in_bits=64 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, vector<8xi32>)
1914+
xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=64 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
19151915
llvm.return
19161916
}
19171917

19181918
// -----
19191919

1920-
llvm.func @invalid_xevm_matrix_2(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
1920+
llvm.func @invalid_xevm_matrix_2(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
19211921
// expected-error@+1 {{op expecting elem_size_in_bits to be 8, 16, 32, or 64}}
1922-
xevm.blockstore2d %c, %base_width_c, %base_height_c, %x, %y, %c_result_casted <{elem_size_in_bits=18 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, vector<8xi32>)
1922+
xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=18 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
19231923
llvm.return
19241924
}
19251925

19261926
// -----
19271927

1928-
llvm.func @invalid_xevm_matrix_3(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
1928+
llvm.func @invalid_xevm_matrix_3(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
19291929
// expected-error@+1 {{op result size of 128 bits does not match the expected size of 208 bits}}
1930-
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=26 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32) -> vector<8xi16>
1930+
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=26 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
19311931
llvm.return %loaded_a : vector<8xi16>
19321932
}
19331933

mlir/test/Dialect/LLVMIR/xevm.mlir

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,59 +2,59 @@
22

33
// CHECK-LABEL: func.func @blockload2d(
44
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>,
5-
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32)
5+
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
66
func.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32,
7-
%x: i32, %y: i32) -> vector<8xi16> {
8-
// CHECK: %[[VAR0:.*]] = xevm.blockload2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]
7+
%base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
8+
// CHECK: %[[VAR0:.*]] = xevm.blockload2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
99
// CHECK-DAG: elem_size_in_bits = 16 : i32
1010
// CHECK-DAG: tile_width = 16 : i32
1111
// CHECK-DAG: tile_height = 8 : i32
1212
// CHECK-DAG: v_blocks = 1 : i32
1313
// CHECK-DAG: transpose = false
1414
// CHECK-DAG: pack_register = false
1515
// CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
16-
// CHECK: (!llvm.ptr<1>, i32, i32, i32, i32) -> vector<8xi16>
17-
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %x, %y
16+
// CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
17+
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
1818
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32,
1919
transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
20-
: (!llvm.ptr<1>, i32, i32, i32, i32) -> vector<8xi16>
20+
: (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
2121
return %loaded_a : vector<8xi16>
2222
}
2323

2424
// -----
2525
// CHECK-LABEL: func.func @blockstore2d(
2626
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>,
27-
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32,
28-
// CHECK-SAME: %[[ARG5:.*]]: vector<8xi32>)
27+
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32,
28+
// CHECK-SAME: %[[ARG6:.*]]: vector<8xi32>)
2929
func.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32,
30-
%x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
31-
// CHECK: xevm.blockstore2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
30+
%base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
31+
// CHECK: xevm.blockstore2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]]
3232
// CHECK-DAG: elem_size_in_bits = 32 : i32
3333
// CHECK-DAG: tile_width = 16 : i32
3434
// CHECK-DAG: tile_height = 8 : i32
35-
// CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, vector<8xi32>)
36-
xevm.blockstore2d %c, %base_width_c, %base_height_c, %x, %y, %c_result_casted
35+
// CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
36+
xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
3737
<{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
38-
: (!llvm.ptr<1>, i32, i32, i32, i32, vector<8xi32>)
38+
: (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
3939
return
4040
}
4141

4242
// -----
4343
// CHECK-LABEL: func.func @blockprefetch2d(
4444
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>,
45-
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32)
45+
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
4646
func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32,
47-
%x: i32, %y: i32) {
48-
// CHECK: xevm.blockprefetch2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]
47+
%base_pitch: i32, %x: i32, %y: i32) {
48+
// CHECK: xevm.blockprefetch2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
4949
// CHECK-DAG: elem_size_in_bits = 8 : i32
5050
// CHECK-DAG: tile_width = 32 : i32
5151
// CHECK-DAG: tile_height = 8 : i32
5252
// CHECK-DAG: v_blocks = 1 : i32
5353
// CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
54-
// CHECK: (!llvm.ptr<1>, i32, i32, i32, i32)
55-
xevm.blockprefetch2d %ptr, %base_width, %base_height, %x, %y <{elem_size_in_bits=8 : i32,
54+
// CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32)
55+
xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32,
5656
tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32,
57-
cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32)
57+
cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
5858
return
5959
}
6060

0 commit comments

Comments
 (0)