Skip to content

Commit 7551a90

Browse files
authored
[XPU][TritonGEN] Revamp SIMD block memory access operations (#2756)
Revamp `tritongen.simdblock[read|write]` operations: - Rename to `tritongen.sub_group_block_[read|write]` - Implement type verification in signature - Represent scalar block memory accesses with a scalar type instead of `vector<1xty>` - Revamp ASM format --------- Signed-off-by: victor-eds <[email protected]>
1 parent 61f5381 commit 7551a90

File tree

8 files changed

+147
-121
lines changed

8 files changed

+147
-121
lines changed

test/TritonGEN/tritongen-invalid.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -428,19 +428,3 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei
428428
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
429429
llvm.return
430430
}
431-
432-
// -----
433-
434-
llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
435-
// expected-error @+1 {{'triton_gen.simdblockread' op unsupported vector type}}
436-
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16>
437-
llvm.return
438-
}
439-
440-
// -----
441-
442-
llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val: vector<64xi16>) {
443-
// expected-error @+1 {{'triton_gen.simdblockwrite' op unsupported vector type}}
444-
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>)
445-
llvm.return
446-
}

test/TritonGEN/tritongen-to-llvm.mlir

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,20 +241,29 @@ llvm.func @triton_gen.dpas.bf16_accum(%c: vector<8xbf16>, %a : vector<8xi16>, %b
241241

242242
// CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, will_return}
243243

244-
llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
245-
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) {
244+
llvm.func @triton_gen.sub_group_block_read(%ptr: !llvm.ptr<3>) {
245+
// CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<3>) {
246246
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<2xi16>
247-
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<2xi16>
247+
%ret = triton_gen.sub_group_block_read %ptr : !llvm.ptr<3> -> vector<2xi16>
248248
llvm.return
249249
}
250250

251251
// -----
252252

253253
// CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(!llvm.ptr<3>, vector<2xi16>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, will_return}
254254

255-
llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
256-
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
255+
llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
256+
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
257257
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<2xi16>) -> ()
258-
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<2xi16>)
258+
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, vector<2xi16>
259+
llvm.return
260+
}
261+
262+
// -----
263+
264+
llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<1>, %val : i32) {
265+
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<1>, %arg1: i32) {
266+
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS1jj(%arg0, %arg1) {{.*}} : (!llvm.ptr<1>, i32) -> ()
267+
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, i32
259268
llvm.return
260269
}

test/TritonGEN/tritongen.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,16 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base
125125
llvm.return
126126
}
127127

128-
llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) {
129-
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr) {
130-
// CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<2xi16>
131-
triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<2xi16>
128+
llvm.func @triton_gen.sub_group_block_read(%ptr : !llvm.ptr<1>) {
129+
// CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<1>) {
130+
// CHECK-NEXT: triton_gen.sub_group_block_read %arg0 : !llvm.ptr<1> -> vector<2xi16>
131+
triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<2xi16>
132132
llvm.return
133133
}
134134

135-
llvm.func @triton_gen.simdblockwrite(%ptr : !llvm.ptr, %val : vector<2xi16>) {
136-
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr, %arg1: vector<2xi16>) {
137-
// CHECK-NEXT: triton_gen.simdblockwrite %arg0, %arg1 : (!llvm.ptr, vector<2xi16>)
138-
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr, vector<2xi16>)
135+
llvm.func @triton_gen.sub_group_block_write(%ptr : !llvm.ptr<3>, %val : i32) {
136+
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: i32) {
137+
// CHECK-NEXT: triton_gen.sub_group_block_write %arg0, %arg1 : !llvm.ptr<3>, i32
138+
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, i32
139139
llvm.return
140140
}

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -314,46 +314,96 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
314314
let hasVerifier = 1;
315315
}
316316

317-
def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
318-
Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>,
319-
Arguments<(ins
320-
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
321-
)> {
322-
323-
let summary = "simd block read";
317+
def TritonGEN_SubGroupBlockMemoryAccessElementType
318+
: AnyTypeOf<[I8, I16, I32, I64],
319+
"Valid sub-group block memory access element type">;
320+
321+
def TritonGEN_SubGroupBlockMemoryAccessType
322+
: AnyTypeOf<[TritonGEN_SubGroupBlockMemoryAccessElementType,
323+
FixedVectorOfLengthAndType<[2, 4, 8],
324+
[TritonGEN_SubGroupBlockMemoryAccessElementType]>,
325+
// Vectors of length 16 only allowed for i8 for now.
326+
FixedVectorOfLengthAndType<[16], [I8]>],
327+
"Valid sub-group block memory access type">;
328+
329+
def TritonGEN_SubGroupBlockMemoryAccessPointerType
330+
: Type<And<[LLVM_AnyPointer.predicate,
331+
Or<[CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
332+
".getAddressSpace() == " #
333+
"static_cast<unsigned>(kCrossWorkgroup)">,
334+
CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
335+
".getAddressSpace() == " #
336+
"static_cast<unsigned>(kWorkgroup)">]>]>,
337+
"LLVM pointer in local or global OpenCL address space",
338+
"::mlir::LLVM::LLVMPointerType">;
339+
340+
def TritonGEN_SubGroupBlockReadOp: TritonGEN_Op<"sub_group_block_read"> {
341+
let summary = "Sub-group block read.";
324342

325343
let description = [{
326-
The `triton_gen.simdblockread` operation performs simd block read from
327-
a start address without laneId offset. The parameters are:
328-
$ptr - the base address to read data
344+
The `triton_gen.sub_group_block_read` reads a scalar or vector for each
345+
work-item in the sub-group from pointer `ptr` as a block operation.
346+
The data is read strided, so the first value is read from:
347+
```
348+
ptr[sub_group_local_id]
349+
```
350+
and the second one is read from:
351+
```
352+
ptr[sub_group_local_id + sub_group_size]
353+
```
354+
etc.
355+
356+
`ptr` must be aligned to the size of the element type of `res`.
357+
358+
Example:
359+
```mlir
360+
%0 = triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<4xi32>
361+
```
329362
}];
330363

364+
let arguments = (ins
365+
Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemRead]>:$ptr);
366+
367+
let results = (outs TritonGEN_SubGroupBlockMemoryAccessType:$res);
368+
331369
let assemblyFormat = [{
332-
operands ` ` attr-dict `:` functional-type(operands, results)
370+
$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)
333371
}];
334-
335-
let hasVerifier = 1;
336372
}
337373

338-
def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
339-
Arguments<(ins
340-
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
341-
FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val
342-
)> {
343-
374+
def TritonGEN_SubGroupBlockWriteOp : TritonGEN_Op<"sub_group_block_write"> {
344375
let summary = "Sub-group block write.";
345376

346377
let description = [{
347-
The `triton_gen.simdblockwrite` operation performs simd block write to
348-
a start address without laneId offset. The parameters are:
349-
$ptr - the base address to be written
350-
$val - the value vector to write
378+
The `triton_gen.sub_group_block_write` writes a scalar or vector for each
379+
work-item in the sub-group to pointer `ptr` as a block operation.
380+
The data is written strided, so the first value is written to:
381+
```
382+
ptr[sub_group_local_id]
383+
```
384+
and the second one is written to:
385+
```
386+
ptr[sub_group_local_id + sub_group_size]
387+
```
388+
etc.
389+
390+
`ptr` must be aligned to the size of the element type of `val`.
391+
392+
Example:
393+
```mlir
394+
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, vector<4xi32>
395+
```
351396
}];
352397

398+
let arguments = (ins
399+
Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemWrite]>:$ptr, // The op stores through $ptr, so it declares a write effect.
400+
TritonGEN_SubGroupBlockMemoryAccessType:$val);
401+
402+
let results = (outs);
403+
353404
let assemblyFormat = [{
354-
operands ` ` attr-dict `:` `(` type(operands) `)`
405+
$ptr `,` $val attr-dict `:` qualified(type($ptr)) `,` type($val)
355406
}];
356-
357-
let hasVerifier = 1;
358407
}
408+
359409
#endif // TRITONGEN_OPS

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,6 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
4848
return success();
4949
}
5050

51-
static LogicalResult verifySIMDBlockTy(Operation *op, VectorType vecTy) {
52-
unsigned numElems = vecTy.getNumElements();
53-
IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());
54-
55-
// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
56-
if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 &&
57-
(elemTy.getWidth() != 8 || numElems != 16))
58-
return op->emitOpError("unsupported vector type");
59-
60-
return success();
61-
}
62-
6351
//===----------------------------------------------------------------------===//
6452
// gen.sub_group_reduce
6553
//===----------------------------------------------------------------------===//
@@ -438,19 +426,3 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {
438426

439427
return success();
440428
}
441-
442-
//===----------------------------------------------------------------------===//
443-
// gen.simdblockread
444-
//===----------------------------------------------------------------------===//
445-
446-
LogicalResult TritonGEN::SIMDBlockReadOp::verify() {
447-
return verifySIMDBlockTy(*this, getRes().getType());
448-
}
449-
450-
//===----------------------------------------------------------------------===//
451-
// gen.simdblockwrite
452-
//===----------------------------------------------------------------------===//
453-
454-
LogicalResult TritonGEN::SIMDBlockWriteOp::verify() {
455-
return verifySIMDBlockTy(*this, getVal().getType());
456-
}

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
2929

3030
#include "llvm/ADT/StringRef.h"
31+
#include "llvm/ADT/TypeSwitch.h"
32+
#include "llvm/ADT/identity.h"
3133
#include "llvm/IR/Attributes.h"
3234
#include "llvm/Support/ErrorHandling.h"
3335
#include "llvm/Support/ModRef.h"
@@ -935,69 +937,77 @@ struct TritonMatrix2DBlockPrefetchLowering
935937
};
936938

937939
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
938-
OpType, TritonGEN::SIMDBlockReadOp,
939-
TritonGEN::SIMDBlockWriteOp>::value>>
940-
static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) {
940+
OpType, TritonGEN::SubGroupBlockReadOp,
941+
TritonGEN::SubGroupBlockWriteOp>::value>>
942+
static std::string getSubGroupBlockManglingName(OpType op, Type type) {
941943
constexpr bool isWrite =
942-
std::is_same<OpType, TritonGEN::SIMDBlockWriteOp>::value;
944+
std::is_same<OpType, TritonGEN::SubGroupBlockWriteOp>::value;
943945
const LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
944-
const unsigned numElems = vecTy.getNumElements();
945946
// Note: OCL builtin name here differs from regular mangling.
946947
std::string funcName = "intel_sub_group_block_";
947948
if constexpr (isWrite)
948949
funcName += "write";
949950
else
950951
funcName += "read";
951-
funcName += "_u" + intel::getTypeMangling(vecTy.getElementType()) +
952-
(numElems == 1 ? "" : std::to_string(numElems));
953-
funcName =
954-
"_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
955-
std::to_string(ptrTy.getAddressSpace()) +
956-
intel::getTypeMangling(vecTy.getElementType(), /*isUnsigned=*/true);
952+
Type elementType =
953+
TypeSwitch<Type, Type>(type)
954+
.Case([](VectorType vecType) { return vecType.getElementType(); })
955+
// Scalar case
956+
.Default(llvm::identity<Type>());
957+
const unsigned numElems =
958+
TypeSwitch<Type, unsigned>(type)
959+
.Case([](VectorType vecType) { return vecType.getNumElements(); })
960+
// Scalar case
961+
.Default(0u);
962+
funcName += "_u" + intel::getTypeMangling(elementType) +
963+
(numElems ? std::to_string(numElems) : "");
964+
funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
965+
std::to_string(ptrTy.getAddressSpace()) +
966+
intel::getTypeMangling(elementType, /*isUnsigned=*/true);
957967
if constexpr (isWrite)
958-
funcName += intel::getTypeMangling(vecTy, /*isUnsigned=*/true);
968+
funcName += intel::getTypeMangling(type, /*isUnsigned=*/true);
959969
return funcName;
960970
}
961971

962-
struct TritonSIMDBlockReadLowering
963-
: public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockReadOp> {
972+
struct TritonSubGroupBlockReadLowering
973+
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupBlockReadOp> {
964974
using ConvertOpToLLVMPattern<
965-
TritonGEN::SIMDBlockReadOp>::ConvertOpToLLVMPattern;
975+
TritonGEN::SubGroupBlockReadOp>::ConvertOpToLLVMPattern;
966976

967977
LogicalResult
968-
matchAndRewrite(TritonGEN::SIMDBlockReadOp op, OpAdaptor adaptor,
978+
matchAndRewrite(TritonGEN::SubGroupBlockReadOp op, OpAdaptor adaptor,
969979
ConversionPatternRewriter &rewriter) const override {
970980
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
971-
VectorType vecTy = op.getRes().getType();
981+
Type type = op.getRes().getType();
972982

973-
std::string funcName = getSIMDBlockManglingName(op, vecTy);
983+
std::string funcName = getSubGroupBlockManglingName(op, type);
974984
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
975985
/*other=*/LLVM::ModRefInfo::NoModRef,
976986
/*argMem=*/LLVM::ModRefInfo::Ref,
977987
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
978988
auto funcAttrs = noUnwindWillReturnAttrs;
979989
funcAttrs.memEffectsAttr = memAttr;
980990
LLVM::CallOp call = createDeviceFunctionCall(
981-
rewriter, funcName, vecTy, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});
991+
rewriter, funcName, type, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});
982992

983993
rewriter.replaceOp(op, call.getResult());
984994
return success();
985995
}
986996
};
987997

988-
struct TritonSIMDBlockWriteLowering
989-
: public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockWriteOp> {
998+
struct TritonSubGroupBlockWriteLowering
999+
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupBlockWriteOp> {
9901000
using ConvertOpToLLVMPattern<
991-
TritonGEN::SIMDBlockWriteOp>::ConvertOpToLLVMPattern;
1001+
TritonGEN::SubGroupBlockWriteOp>::ConvertOpToLLVMPattern;
9921002

9931003
LogicalResult
994-
matchAndRewrite(TritonGEN::SIMDBlockWriteOp op, OpAdaptor adaptor,
1004+
matchAndRewrite(TritonGEN::SubGroupBlockWriteOp op, OpAdaptor adaptor,
9951005
ConversionPatternRewriter &rewriter) const override {
9961006
MLIRContext *ctx = rewriter.getContext();
9971007
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
998-
VectorType vecTy = op.getVal().getType();
1008+
Type type = op.getVal().getType();
9991009

1000-
std::string funcName = getSIMDBlockManglingName(op, vecTy);
1010+
std::string funcName = getSubGroupBlockManglingName(op, type);
10011011

10021012
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
10031013
/*other=*/LLVM::ModRefInfo::NoModRef,
@@ -1006,7 +1016,7 @@ struct TritonSIMDBlockWriteLowering
10061016
auto funcAttrs = noUnwindWillReturnAttrs;
10071017
funcAttrs.memEffectsAttr = memAttr;
10081018
LLVM::CallOp call = createDeviceFunctionCall(
1009-
rewriter, funcName, void_ty(ctx), {ptrTy, vecTy},
1019+
rewriter, funcName, void_ty(ctx), {ptrTy, type},
10101020
{op.getPtr(), op.getVal()}, {}, funcAttrs);
10111021

10121022
rewriter.replaceOp(op, call);
@@ -1071,12 +1081,13 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
10711081

10721082
void mlir::triton::populateTritonGENToLLVMConversionPatterns(
10731083
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1074-
patterns.add<
1075-
TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering,
1076-
TritonSubGroupReduceLowering, TritonSubGroupScanLowering,
1077-
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
1078-
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering,
1079-
TritonSIMDBlockReadLowering, TritonSIMDBlockWriteLowering>(converter);
1084+
patterns
1085+
.add<TritonGENSplitBarrierSignalLowering,
1086+
TritonGENSplitBarrierWaitLowering, TritonSubGroupReduceLowering,
1087+
TritonSubGroupScanLowering, TritonMatrixDPASLowering,
1088+
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
1089+
TritonMatrix2DBlockPrefetchLowering, TritonSubGroupBlockReadLowering,
1090+
TritonSubGroupBlockWriteLowering>(converter);
10801091
}
10811092

10821093
void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {

0 commit comments

Comments
 (0)