@@ -149,10 +149,13 @@ def TileZeroOp : AMX_Op<"tile_zero", [
149149 let summary = "tile zero operation";
150150 let description = [{
151151 Zeroes the destination tile, with the shape defined by the 2-dim
152- vector type of the result. This is eventually lowered into the
153- "tilezero" instruction with the corresponding tile configuration.
154- With memory-effects, each "tilezero" operation serves as a compilation
155- hint to use a separate tile register.
152+ vector type of the result.
153+
154+ The operation is eventually lowered into the "tilezero" instruction
155+ with the corresponding tile configuration.
156+
157+ With the write memory effect, each `amx.tile_zero` operation serves as
158+ a compilation hint to use a separate tile register.
156159
157160 Example:
158161
@@ -184,25 +187,53 @@ def TileZeroOp : AMX_Op<"tile_zero", [
184187
185188def TileLoadOp : AMX_Op<"tile_load", [
186189 AMXIntrinsicOpInterface,
187- MemoryEffects<[MemWrite]>
190+ MemoryEffects<[MemWrite]>,
191+ AttrSizedOperandSegments
188192 ]> {
189193 let summary = "tile load operation";
190194 let description = [{
191- Loads a tile from memory defined by a base and indices, with the
192- shape defined by the 2-dim vector type of the result. This is
193- eventually lowered into the "tileloadd" instruction with the
194- corresponding tile configuration. With memory-effects, each "tileload"
195- operation serves as a compilation hint to use a separate tile register.
195+ Loads a tile from memory defined by a `base` and `indices`, with the
196+ shape defined by the 2-dim vector type of the result.
197+ The tile's rows are populated by reading contiguous elements starting
198+ at the `base`. For each tile row, the `base` is incremented by `stride`
199+ number of elements.
200+
201+ The tile is loaded using the following indexing scheme:
202+
203+ ```
204+ for row in range(tile_rows):
205+ mem_row = base[i0, i1, ..., iN + row * stride]
206+ for col in range(tile_cols):
207+ tile[row, col] = mem_row[col]
208+ ```
209+
210+ If the `stride` is not provided, then the `base` buffer must be at least
211+ 2-dimensional, and the `stride` is automatically inferred and corresponds
212+ to the stride of the buffer's second innermost dimension.
213+
214+ The operation is eventually lowered into the "tileloadd" instruction
215+ with the corresponding tile configuration.
216+
217+ With the write memory effect, each `amx.tile_load` operation serves as
218+ a compilation hint to use a separate tile register.
196219
197220 Example:
198221
199222 ```mlir
223+ // Tile load from a 2-D memref with implicit stride.
200224 %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
225+
226+ // Tile load from a 1-D memref with explicit stride.
227+ %1 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
201228 ```
202229 }];
203230 let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
204- Variadic<Index>:$indices);
231+ Variadic<Index>:$indices,
232+ Optional<Index>:$stride);
205233 let results = (outs AnyAMXTile:$res);
234+ let builders = [
235+ OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
236+ ];
206237 let extraClassDeclaration = [{
207238 MemRefType getMemRefType() {
208239 return ::llvm::cast<MemRefType>(getBase().getType());
@@ -219,30 +250,56 @@ def TileLoadOp : AMX_Op<"tile_load", [
219250 const ::mlir::LLVMTypeConverter &typeConverter,
220251 ::mlir::RewriterBase &rewriter);
221252 }];
222- let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
223- "type($base) `into` qualified(type($res))";
253+ let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict "
254+ "`:` type($base) `into` qualified(type($res))";
224255 let hasVerifier = 1;
225256}
226257
227258def TileStoreOp : AMX_Op<"tile_store", [
228- AMXIntrinsicOpInterface
259+ AMXIntrinsicOpInterface,
260+ AttrSizedOperandSegments
229261 ]> {
230262 let summary = "tile store operation";
231263 let description = [{
232- Stores a tile to memory defined by a base and indices, with the
233- shape defined by the 2-dim vector type of the value. This is
234- eventually lowered into the "tilestored" instruction with the
235- corresponding tile configuration.
264+ Stores a tile to memory defined by a `base` and `indices`, with the
265+ shape defined by the 2-dim vector type of the value.
266+ The tile's rows are written contiguously to the buffer starting at
267+ the `base`. For each tile row, the `base` is incremented by `stride`
268+ number of elements.
269+
270+ The tile is stored using the following indexing scheme:
271+
272+ ```
273+ for row in range(tile_rows):
274+ mem_row = base[i0, i1, ..., iN + row * stride]
275+ for col in range(tile_cols):
276+ mem_row[col] = tile[row, col]
277+ ```
278+
279+ If the `stride` is not provided, then the `base` buffer must be at least
280+ 2-dimensional, and the `stride` is automatically inferred and corresponds
281+ to the stride of the buffer's second innermost dimension.
282+
283+ The operation is eventually lowered into the "tilestored" instruction
284+ with the corresponding tile configuration.
236285
237286 Example:
238287
239288 ```mlir
289+ // Tile store to a 2-D memref with implicit stride.
240290 amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
291+
292+ // Tile store to a 1-D memref with explicit stride.
293+ amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
241294 ```
242295 }];
243296 let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
244297 Variadic<Index>:$indices,
245- AnyAMXTile:$val);
298+ AnyAMXTile:$val,
299+ Optional<Index>:$stride);
300+ let builders = [
301+ OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
302+ ];
246303 let extraClassDeclaration = [{
247304 MemRefType getMemRefType() {
248305 return ::llvm::cast<MemRefType>(getBase().getType());
@@ -259,8 +316,8 @@ def TileStoreOp : AMX_Op<"tile_store", [
259316 const ::mlir::LLVMTypeConverter &typeConverter,
260317 ::mlir::RewriterBase &rewriter);
261318 }];
262- let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
263- "type($base) `,` qualified(type($val))";
319+ let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )? "
320+ "attr-dict `:` type($base) `,` qualified(type($val))";
264321 let hasVerifier = 1;
265322}
266323
@@ -276,8 +333,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
276333 let description = [{
277334 Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
278335 into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
279- pairs of "bf16"). The operation is eventually lowered into the
280- "tdpbf16ps" instruction with the corresponding tile configuration.
336+ pairs of "bf16").
337+
338+ The operation is eventually lowered into the "tdpbf16ps" instruction with
339+ the corresponding tile configuration.
281340
282341 Example:
283342
@@ -330,9 +389,11 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
330389 into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
331390 combinations (4 bytes packed into dwords in the columns of both the
332391 source operand tiles; the zero or sign extension is specified with
333- the attributes and default to sign extended). The operation is eventually
334- lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
335- instructions with the corresponding tile configuration.
392+ the attributes and defaults to sign extension).
393+
394+ The operation is eventually lowered into one of the "tdpbssd",
395+ "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
396+ tile configuration.
336397
337398 Example:
338399
0 commit comments