@@ -314,46 +314,96 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
314314 let hasVerifier = 1;
315315}
316316
317- def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
318- Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>,
319- Arguments<(ins
320- Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
321- )> {
322-
323- let summary = "simd block read";
317+ def TritonGEN_SubGroupBlockMemoryAccessElementType // Element types usable in sub-group block accesses.
318+ : AnyTypeOf<[I8, I16, I32, I64], // 8-, 16-, 32- or 64-bit integers only.
319+ "Valid sub-group block memory access element type">;
320+
321+ def TritonGEN_SubGroupBlockMemoryAccessType // Value type read or written per work-item.
322+ : AnyTypeOf<[TritonGEN_SubGroupBlockMemoryAccessElementType, // A single scalar element, or...
323+ FixedVectorOfLengthAndType<[2, 4, 8], // ...a fixed vector of 2, 4 or 8 valid elements, or...
324+ [TritonGEN_SubGroupBlockMemoryAccessElementType]>,
325+ // ...a vector of length 16, which is only allowed for i8 elements for now.
326+ FixedVectorOfLengthAndType<[16], [I8]>],
327+ "Valid sub-group block memory access type">;
328+
329+ def TritonGEN_SubGroupBlockMemoryAccessPointerType // Pointer operand constraint for sub-group block accesses.
330+ : Type<And<[LLVM_AnyPointer.predicate, // Must be an LLVM pointer whose address space is one of the two below.
331+ Or<[CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
332+ ".getAddressSpace() == " #
333+ "static_cast<unsigned>(kCrossWorkgroup)">, // CrossWorkgroup: presumably the OpenCL global address space.
334+ CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
335+ ".getAddressSpace() == " #
336+ "static_cast<unsigned>(kWorkgroup)">]>]>, // Workgroup: presumably the OpenCL local address space.
337+ "LLVM pointer in local or global OpenCL address space",
338+ "::mlir::LLVM::LLVMPointerType">;
339+
340+ def TritonGEN_SubGroupBlockReadOp: TritonGEN_Op<"sub_group_block_read"> {
341+ let summary = "Sub-group block read.";
324342
325343 let description = [{
326- The `triton_gen.simdblockread` operation performs simd block read from
327- a start address without laneId offset. The parameters are:
328- $ptr - the base address to read data
344+ The `triton_gen.sub_group_block_read` reads a scalar or vector for each
345+ work-item in the sub-group from pointer `ptr` as a block operation.
346+ The data is read strided, so the first value is read from:
347+ ```
348+ ptr[sub_group_local_id]
349+ ```
350+ and the second one is:
351+ ```
352+ ptr[sub_group_local_id + sub_group_size]
353+ ```
354+ etc.
355+
356+ `ptr` must be aligned to the size of the element type of `res`.
357+
358+ Example:
359+ ```mlir
360+ %0 = triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<4xi32>
361+ ```
329362 }];
330363
364+ let arguments = (ins
365+ Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemRead]>:$ptr);
366+
367+ let results = (outs TritonGEN_SubGroupBlockMemoryAccessType:$res);
368+
331369 let assemblyFormat = [{
332- operands ` ` attr-dict `:` functional-type(operands, results)
370+ $ptr attr-dict `:` qualified(type($ptr)) `->` type($res)
333371 }];
334-
335- let hasVerifier = 1;
336373}
337373
338- def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
339- Arguments<(ins
340- Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
341- FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val
342- )> {
343-
374+ def TritonGEN_SubGroupBlockWriteOp : TritonGEN_Op<"sub_group_block_write"> {
344375 let summary = "Sub-group block write.";
345376
346377 let description = [{
347- The `triton_gen.simdblockwrite` operation performs simd block write to
348- a start address without laneId offset. The parameters are:
349- $ptr - the base address to be written
350- $val - the value vector to write
378+ The `triton_gen.sub_group_block_write` writes a scalar or vector for each
379+ work-item in the sub-group to pointer `ptr` as a block operation.
380+ The data is written strided, so the first value is written to:
381+ ```
382+ ptr[sub_group_local_id]
383+ ```
384+ and the second one to:
385+ ```
386+ ptr[sub_group_local_id + sub_group_size]
387+ ```
388+ etc.
389+
390+ `ptr` must be aligned to the size of the element type of `val`.
391+
392+ Example:
393+ ```mlir
394+ triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, vector<4xi32>
395+ ```
351396 }];
352397
398+ let arguments = (ins
399+ Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemWrite]>:$ptr,
400+ TritonGEN_SubGroupBlockMemoryAccessType:$val);
401+
402+ let results = (outs);
403+
353404 let assemblyFormat = [{
354- operands ` ` attr-dict `:` `(` type(operands) `)`
405+ $ptr `,` $val attr-dict `:` qualified(type($ptr)) `,` type($val)
355406 }];
356-
357- let hasVerifier = 1;
358407}
408+
359409#endif // TRITONGEN_OPS
0 commit comments