@@ -191,9 +191,9 @@ def XeVM_BlockLoad2dOp
191191 : XeVM_Op<"blockload2d">,
192192 Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>,
193193 Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
194- I32:$base_height, I32:$x , I32:$y, I32Attr:$elem_size_in_bits ,
195- I32Attr:$tile_width , I32Attr:$tile_height , I32Attr:$v_blocks ,
196- I1Attr:$transpose, I1Attr:$pack_register,
194+ I32:$base_height, I32:$base_pitch , I32:$x, I32:$y ,
195+ I32Attr:$elem_size_in_bits , I32Attr:$tile_width , I32Attr:$tile_height ,
196+ I32Attr:$v_blocks, I1Attr:$transpose, I1Attr:$pack_register,
197197 OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
198198
199199 let summary = "2D block load";
@@ -202,7 +202,9 @@ def XeVM_BlockLoad2dOp
202202 The `xevm.blockload2d` operation loads a two dimensional matrix tile
203203 from a base matrix residing in global memory. The parameters are:
204204 $ptr - the base address of the base matrix containing the tile to load
205- $base_width, $base_height, the shape of the base matrix in number of elements.
205+ $base_width, $base_height, $base_pitch - the shape of the base matrix.
206+ pitch is the physical stride between the first columns of the current row
207+ and the subsequent row. All units are in bytes.
206208 $x, $y, $tile_width, $tile_height - the starting offsets and shape of
207209 the tile to load in number of elements.
208210 $elem_size_in_bits - the size in bits of the matrix element type
@@ -225,9 +227,10 @@ def XeVM_BlockLoad2dOp
225227 ```mlir
226228 %base_width_a = arith.constant 32 : i32
227229 %base_height_a = arith.constant 8 : i32
230+ %base_pitch_a = arith.constant 32 : i32
228231 %x = arith.constant 0 : i32
229232 %y = arith.constant 0 : i32
230- %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %x, %y
233+ %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, % x, %y
231234 <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32,
232235 v_blocks=1 : i32, transpose=false : i32, pack_register=false,
233236 cache_control=#xevm.load_cache_control<Default>}>
@@ -248,8 +251,8 @@ def XeVM_BlockLoad2dOp
248251def XeVM_BlockStore2dOp
249252 : XeVM_Op<"blockstore2d">,
250253 Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr, I32:$base_width,
251- I32:$base_height, I32:$x , I32:$y, I32Attr:$elem_size_in_bits ,
252- I32Attr:$tile_width, I32Attr:$tile_height,
254+ I32:$base_height, I32:$base_pitch , I32:$x, I32:$y ,
255+ I32Attr:$elem_size_in_bits, I32Attr:$ tile_width, I32Attr:$tile_height,
253256 FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$stored_val,
254257 OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
255258
@@ -259,9 +262,11 @@ def XeVM_BlockStore2dOp
259262 The `xevm.blockstore2d` operation stores a two dimensional tile into a
260263 larger matrix residing in global memory. The parameters are:
261264 $ptr - the base address of the target matrix where to store the tile
262- $base_width, $base_height, the shape of the target matrix in number of elements.
265+ $base_width, $base_height, $base_pitch - the shape of the target matrix. pitch is the
266+ physical stride between the first columns of the current row and the subsequent row.
267+ All units are in bytes.
263268 $x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to store
264- in number of elements.
269+ in number of elements.
265270 $elem_size_in_bits - the size in bits of the matrix element
266271 - 32 for f32, tf32
267272 - 16 for f16, int16, bf16
@@ -273,9 +278,10 @@ def XeVM_BlockStore2dOp
273278 ```mlir
274279 %base_width_c = arith.constant 64 : i32
275280 %base_height_c = arith.constant 8 : i32
281+ %base_pitch_c = arith.constant 64 : i32
276282 %x = arith.constant 0 : i32
277283 %y = arith.constant 0 : i32
278- xevm.blockstore2d %dst, %base_width_c, %base_height_c, %x, %y, %src
284+ xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, % x, %y, %src
279285 <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32,
280286 cache_control=#xevm.load_cache_control<Default>}>
281287 : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
@@ -377,8 +383,9 @@ def XeVM_PrefetchOp
377383def XeVM_BlockPrefetch2dOp
378384 : XeVM_Op<"blockprefetch2d">,
379385 Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
380- I32:$base_height, I32:$x, I32:$y, I32Attr:$elem_size_in_bits,
381- I32Attr:$tile_width, I32Attr:$tile_height, I32Attr:$v_blocks,
386+ I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
387+ I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
388+ I32Attr:$v_blocks,
382389 OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
383390
384391 let summary = "2D block prefetch";
@@ -387,7 +394,9 @@ def XeVM_BlockPrefetch2dOp
387394 The `xevm.blockprefetch2d` operation prefetches a two dimensional tile
388395 from a larger base matrix residing in global memory. The parameters are:
389396 $ptr - the base address of the base matrix containing the tile to prefetch
390- $base_width, $base_height - the shape of the base matrix in number of elements.
397+ $base_width, $base_height, $base_pitch - the shape of the base matrix.
398+ pitch is the physical stride between the first columns of the current row
399+ and the subsequent row. All units are in bytes.
391400 $x, $y, $tile_width, $tile_height - the starting offsets and shape of tile
392401 to prefetch in number of elements.
393402 $elem_size_in_bits - the size in bits of the matrix element
@@ -399,7 +408,7 @@ def XeVM_BlockPrefetch2dOp
399408
400409 Example:
401410 ```mlir
402- xevm.blockprefetch2d %ptr, %base_width, %base_height, %x, %y
411+ xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, % x, %y
403412 <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32,
404413 v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
405414 : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
0 commit comments