Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions test/TritonGEN/tritongen-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -428,19 +428,3 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
llvm.return
}

// -----

llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
// expected-error @+1 {{'triton_gen.simdblockread' op unsupported vector type}}
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16>
llvm.return
}

// -----

llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val: vector<64xi16>) {
// expected-error @+1 {{'triton_gen.simdblockwrite' op unsupported vector type}}
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>)
llvm.return
}
21 changes: 15 additions & 6 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -241,20 +241,29 @@ llvm.func @triton_gen.dpas.bf16_accum(%c: vector<8xbf16>, %a : vector<8xi16>, %b

// CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, will_return}

llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) {
llvm.func @triton_gen.sub_group_block_read(%ptr: !llvm.ptr<3>) {
// CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<3>) {
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<2xi16>
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<2xi16>
%ret = triton_gen.sub_group_block_read %ptr : !llvm.ptr<3> -> vector<2xi16>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(!llvm.ptr<3>, vector<2xi16>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, will_return}

llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<2xi16>) -> ()
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<2xi16>)
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, vector<2xi16>
llvm.return
}

// -----

llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<1>, %val : i32) {
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<1>, %arg1: i32) {
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS1jj(%arg0, %arg1) {{.*}} : (!llvm.ptr<1>, i32) -> ()
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, i32
llvm.return
}
16 changes: 8 additions & 8 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,16 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base
llvm.return
}

llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) {
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr) {
// CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<2xi16>
triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<2xi16>
llvm.func @triton_gen.sub_group_block_read(%ptr : !llvm.ptr<1>) {
// CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<1>) {
// CHECK-NEXT: triton_gen.sub_group_block_read %arg0 : !llvm.ptr<1> -> vector<2xi16>
triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<2xi16>
llvm.return
}

llvm.func @triton_gen.simdblockwrite(%ptr : !llvm.ptr, %val : vector<2xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr, %arg1: vector<2xi16>) {
// CHECK-NEXT: triton_gen.simdblockwrite %arg0, %arg1 : (!llvm.ptr, vector<2xi16>)
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr, vector<2xi16>)
llvm.func @triton_gen.sub_group_block_write(%ptr : !llvm.ptr<3>, %val : i32) {
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: i32) {
// CHECK-NEXT: triton_gen.sub_group_block_write %arg0, %arg1 : !llvm.ptr<3>, i32
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, i32
llvm.return
}
103 changes: 77 additions & 26 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -314,46 +314,97 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
let hasVerifier = 1;
}

def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
)> {

let summary = "simd block read";
def TritonGEN_SubGroupBlockMemoryAccessElementType
: AnyTypeOf<[I8, I16, I32, I64],
"Valid sub-group block memory access element type">;

def TritonGEN_SubGroupBlockMemoryAccessType
: AnyTypeOf<[TritonGEN_SubGroupBlockMemoryAccessElementType,
FixedVectorOfLengthAndType<
[2, 4, 8],
[TritonGEN_SubGroupBlockMemoryAccessElementType]>,
// Vectors of length 16 only allowed for i8 for now.
FixedVectorOfLengthAndType<[16], [I8]>],
"Valid sub-group block memory access type">;

def TritonGEN_SubGroupBlockMemoryAccessPointerType
: Type<And<[LLVM_AnyPointer.predicate,
Or<[CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
".getAddressSpace() == " #
"static_cast<unsigned>(kCrossWorkgroup)">,
CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
".getAddressSpace() == " #
"static_cast<unsigned>(kWorkgroup)">]>]>,
"LLVM pointer in local or global OpenCL address space",
"::mlir::LLVM::LLVMPointerType">;

def TritonGEN_SubGroupBlockReadOp: TritonGEN_Op<"sub_group_block_read"> {
let summary = "Sub-group block read.";
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder before merge: use spaces instead of tabs for indentation.


let description = [{
The `triton_gen.simdblockread` operation performs simd block read from
a start address without laneId offset. The parameters are:
$ptr - the base address to read data
The `triton_gen.sub_group_block_read` reads a scalar or vector for each
work-item in the sub-group from pointer `ptr` as a block operation.
The data is read strided, so the first value is read from:
```
ptr[sub_group_local_id]
```
and the second one is:
```
ptr[sub_group_local_id + sub_group_size]
```
etc.

`ptr` must be aligned to the size of the element type of `res`.

Example:
```mlir
%0 = triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<4xi32>
```
}];

let arguments = (ins
Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemRead]>:$ptr);

let results = (outs TritonGEN_SubGroupBlockMemoryAccessType:$res);

let assemblyFormat = [{
operands ` ` attr-dict `:` functional-type(operands, results)
$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)
}];

let hasVerifier = 1;
}

def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val
)> {

def TritonGEN_SubGroupBlockWriteOp : TritonGEN_Op<"sub_group_block_write"> {
let summary = "Sub-group block write.";

let description = [{
The `triton_gen.simdblockwrite` operation performs simd block write to
a start address without laneId offset. The parameters are:
$ptr - the base address to be written
$val - the value vector to write
The `triton_gen.sub_group_block_write` writes a scalar or vector for each
work-item in the sub-group to pointer `ptr` as a block operation.
The data is written strided, so the first value is written to:
```
ptr[sub_group_local_id]
```
and the second one to:
```
ptr[sub_group_local_id + sub_group_size]
```
etc.

`ptr` must be aligned to the size of the element type of `val`.

Example:
```mlir
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, vector<4xi32>
```
}];

// Operands of the block write: the destination pointer and the value to
// store. The op writes through `ptr`, so the pointer carries a MemWrite
// effect (matching the `argMem = readwrite` attribute used when lowering to
// the OpenCL builtin). NOTE(review): with only a MemRead effect and no
// results, MLIR could presumably treat the op as trivially dead — confirm.
let arguments = (ins
Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemWrite]>:$ptr,
TritonGEN_SubGroupBlockMemoryAccessType:$val);

let results = (outs);

let assemblyFormat = [{
operands ` ` attr-dict `:` `(` type(operands) `)`
$ptr `,` $val attr-dict `:` qualified(type($ptr)) `,` type($val)
}];

let hasVerifier = 1;
}

#endif // TRITONGEN_OPS
28 changes: 0 additions & 28 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,6 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
return success();
}

static LogicalResult verifySIMDBlockTy(Operation *op, VectorType vecTy) {
unsigned numElems = vecTy.getNumElements();
IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());

// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 &&
(elemTy.getWidth() != 8 || numElems != 16))
return op->emitOpError("unsupported vector type");

return success();
}

//===----------------------------------------------------------------------===//
// gen.sub_group_reduce
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -438,19 +426,3 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {

return success();
}

//===----------------------------------------------------------------------===//
// gen.simdblockread
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::SIMDBlockReadOp::verify() {
return verifySIMDBlockTy(*this, getRes().getType());
}

//===----------------------------------------------------------------------===//
// gen.simdblockwrite
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::SIMDBlockWriteOp::verify() {
return verifySIMDBlockTy(*this, getVal().getType());
}
75 changes: 43 additions & 32 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "mlir/Target/LLVMIR/TypeToLLVM.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/identity.h"
#include "llvm/IR/Attributes.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ModRef.h"
Expand Down Expand Up @@ -935,69 +937,77 @@ struct TritonMatrix2DBlockPrefetchLowering
};

template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, TritonGEN::SIMDBlockReadOp,
TritonGEN::SIMDBlockWriteOp>::value>>
static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) {
OpType, TritonGEN::SubGroupBlockReadOp,
TritonGEN::SubGroupBlockWriteOp>::value>>
static std::string getSubGroupBlockManglingName(OpType op, Type type) {
constexpr bool isWrite =
std::is_same<OpType, TritonGEN::SIMDBlockWriteOp>::value;
std::is_same<OpType, TritonGEN::SubGroupBlockWriteOp>::value;
const LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
const unsigned numElems = vecTy.getNumElements();
// Note: OCL builtin name here differs from regular mangling.
std::string funcName = "intel_sub_group_block_";
if constexpr (isWrite)
funcName += "write";
else
funcName += "read";
funcName += "_u" + intel::getTypeMangling(vecTy.getElementType()) +
(numElems == 1 ? "" : std::to_string(numElems));
funcName =
"_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
std::to_string(ptrTy.getAddressSpace()) +
intel::getTypeMangling(vecTy.getElementType(), /*isUnsigned=*/true);
Type elementType =
TypeSwitch<Type, Type>(type)
.Case([](VectorType vecType) { return vecType.getElementType(); })
// Scalar case
.Default(llvm::identity<Type>());
const unsigned numElems =
TypeSwitch<Type, unsigned>(type)
.Case([](VectorType vecType) { return vecType.getNumElements(); })
// Scalar case
.Default(0u);
funcName += "_u" + intel::getTypeMangling(elementType) +
(numElems ? std::to_string(numElems) : "");
funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
std::to_string(ptrTy.getAddressSpace()) +
intel::getTypeMangling(elementType, /*isUnsigned=*/true);
if constexpr (isWrite)
funcName += intel::getTypeMangling(vecTy, /*isUnsigned=*/true);
funcName += intel::getTypeMangling(type, /*isUnsigned=*/true);
return funcName;
}

struct TritonSIMDBlockReadLowering
: public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockReadOp> {
struct TritonSubGroupBlockReadLowering
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupBlockReadOp> {
using ConvertOpToLLVMPattern<
TritonGEN::SIMDBlockReadOp>::ConvertOpToLLVMPattern;
TritonGEN::SubGroupBlockReadOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(TritonGEN::SIMDBlockReadOp op, OpAdaptor adaptor,
matchAndRewrite(TritonGEN::SubGroupBlockReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
VectorType vecTy = op.getRes().getType();
Type type = op.getRes().getType();

std::string funcName = getSIMDBlockManglingName(op, vecTy);
std::string funcName = getSubGroupBlockManglingName(op, type);
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::Ref,
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
auto funcAttrs = noUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
LLVM::CallOp call = createDeviceFunctionCall(
rewriter, funcName, vecTy, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});
rewriter, funcName, type, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});

rewriter.replaceOp(op, call.getResult());
return success();
}
};

struct TritonSIMDBlockWriteLowering
: public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockWriteOp> {
struct TritonSubGroupBlockWriteLowering
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupBlockWriteOp> {
using ConvertOpToLLVMPattern<
TritonGEN::SIMDBlockWriteOp>::ConvertOpToLLVMPattern;
TritonGEN::SubGroupBlockWriteOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(TritonGEN::SIMDBlockWriteOp op, OpAdaptor adaptor,
matchAndRewrite(TritonGEN::SubGroupBlockWriteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = rewriter.getContext();
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
VectorType vecTy = op.getVal().getType();
Type type = op.getVal().getType();

std::string funcName = getSIMDBlockManglingName(op, vecTy);
std::string funcName = getSubGroupBlockManglingName(op, type);

auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
Expand All @@ -1006,7 +1016,7 @@ struct TritonSIMDBlockWriteLowering
auto funcAttrs = noUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
LLVM::CallOp call = createDeviceFunctionCall(
rewriter, funcName, void_ty(ctx), {ptrTy, vecTy},
rewriter, funcName, void_ty(ctx), {ptrTy, type},
{op.getPtr(), op.getVal()}, {}, funcAttrs);

rewriter.replaceOp(op, call);
Expand Down Expand Up @@ -1071,12 +1081,13 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface {

void mlir::triton::populateTritonGENToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<
TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering,
TritonSubGroupReduceLowering, TritonSubGroupScanLowering,
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering,
TritonSIMDBlockReadLowering, TritonSIMDBlockWriteLowering>(converter);
patterns
.add<TritonGENSplitBarrierSignalLowering,
TritonGENSplitBarrierWaitLowering, TritonSubGroupReduceLowering,
TritonSubGroupScanLowering, TritonMatrixDPASLowering,
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
TritonMatrix2DBlockPrefetchLowering, TritonSubGroupBlockReadLowering,
TritonSubGroupBlockWriteLowering>(converter);
}

void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {
Expand Down
Loading