Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
const TargetInfoBase &targetInfo,
const DataLayoutAnalysis *analysis = nullptr);

Type getElementTypeForStruct(TensorOrMemDesc type);
Type getElementTypeForStruct(triton::gpu::TensorOrMemDesc type);
Type convertTritonPointerType(triton::PointerType type);
Type convertTritonTensorType(RankedTensorType type,
const TargetInfoBase &targetInfo);
Type convertMemDescType(MemDescType type, const TargetInfoBase &targetInfo);
Type convertMemDescType(triton::gpu::MemDescType type,
const TargetInfoBase &targetInfo);
Type convertAsyncToken(triton::gpu::AsyncTokenType type);
};

Expand Down
10 changes: 6 additions & 4 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
Expand Down Expand Up @@ -1141,8 +1142,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
//
// Returns true on success.
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, Value shmemBase,
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
Expand Down Expand Up @@ -1310,13 +1311,14 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
}

SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
MemDescType srcTy, Type elemLlvmTy,
triton::gpu::MemDescType srcTy,
Type elemLlvmTy,
SharedMemoryObject smemObj,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

void storeDistributedToShared(
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);
Expand Down
4 changes: 0 additions & 4 deletions include/triton/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)

set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td)
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)

set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td)
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
Expand Down
6 changes: 3 additions & 3 deletions include/triton/Dialect/Triton/IR/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
static LogicalResult verifyTrait(Operation *op) {
if (op->getNumOperands() < 3)
return op->emitOpError("expected at least 3 operands");
auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
auto cTy = cast<TensorOrMemDesc>(op->getOperand(2).getType());
auto aTy = cast<ShapedType>(op->getOperand(0).getType());
auto bTy = cast<ShapedType>(op->getOperand(1).getType());
auto cTy = cast<ShapedType>(op->getOperand(2).getType());
auto aShape = aTy.getShape();
auto bShape = bTy.getShape();
auto cShape = cTy.getShape();
Expand Down
1 change: 0 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"


Expand Down
48 changes: 0 additions & 48 deletions include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -92,54 +92,6 @@ def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>;
// Any Type in Triton IR
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>;

// Memory descriptor type.
def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
let summary = "memory descriptor type (`::mlir::triton::MemDescType`) in Triton IR type system";

let description = [{
Memory descriptor contains a base pointer (scalar) and a descriptor of the memory.
If mutable memory is false that means the memory is constant and can only be allocated and stored once.
A constant memory allocation is different than a tensor as it can have multiple views and the descriptor
can be changed without changing the underlying memory.
}];

let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutable_memory
);
let extraClassDeclaration = [{
MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
}

bool hasRank() const { return true; }
}];
let builders = [
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
}]>,
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutableMemory
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
}]>
];
let hasCustomAssemblyFormat = 1;
}

// Result type of ExperimentalMakeTensorDescriptor
def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";
Expand Down
2 changes: 0 additions & 2 deletions include/triton/Dialect/Triton/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.h.inc"

#include "triton/Dialect/Triton/IR/TypeInterfaces.h.inc"

namespace mlir {

namespace triton {
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/Attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"

#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#include "triton/Dialect/TritonGPU/IR/AttrDefs.h.inc"

#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_
13 changes: 9 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonGPUTableGen)

set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(TritonGPUAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(TritonGPUAttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(TritonGPUTypeInterfacesIncGen)
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Types.h"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

namespace mlir {
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#define TRITONGPU_ATTRDEFS

include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"

//===----------------------------------------------------------------------===//
// TritonGPU Attribute Definitions
Expand Down
5 changes: 4 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef TRITON_GPU_DIALECT_INTERFACES_H
#define TRITON_GPU_DIALECT_INTERFACES_H

// clang-format off
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc"
#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc"
// clang-format on

#endif // TRITON_GPU_DIALECT_INTERFACES_H
29 changes: 15 additions & 14 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
Expand Down Expand Up @@ -95,7 +96,7 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [

let arguments = (
ins TT_PtrTensor:$src,
TT_MemDescType:$result,
TTG_MemDescType:$result,
Optional<I1Tensor>:$mask,
Optional<TT_Type>:$other,
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
Expand Down Expand Up @@ -168,7 +169,7 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf
}];
let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}];

let results = (outs TT_MemDescType:$result);
let results = (outs TTG_MemDescType:$result);
let hasFolder = 1;
let hasVerifier = 1;
}
Expand All @@ -191,9 +192,9 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedM
operand.
}];

let arguments = (ins TT_MemDescType:$src);
let arguments = (ins TTG_MemDescType:$src);

// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}];
}

Expand All @@ -212,12 +213,12 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
Then in Python syntax, the subview covers input[1][0:4][4:8].
}];
let arguments = (
ins TT_MemDescType:$src, Variadic<I32>:$offsets);
ins TTG_MemDescType:$src, Variadic<I32>:$offsets);

// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];

let results = (outs TT_MemDescType:$result);
let results = (outs TTG_MemDescType:$result);

let hasVerifier = 1;
}
Expand All @@ -233,14 +234,14 @@ def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
representing a transposed view of the buffer.
}];

let arguments = (ins TT_MemDescType:$src, Variadic<I32>:$order);
let arguments = (ins TTG_MemDescType:$src, Variadic<I32>:$order);

let arguments = (
ins TT_MemDescType:$src,
ins TTG_MemDescType:$src,
DenseI32ArrayAttr:$order
);

let results = (outs TT_MemDescType:$result);
let results = (outs TTG_MemDescType:$result);

let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";

Expand All @@ -253,15 +254,15 @@ def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffe
let description = [{
Load a tensor from the local memory descriptor into a distributed tensor.
}];
let arguments = (ins TT_MemDescType:$src, Optional<TTG_AsyncToken> :$token);
let arguments = (ins TTG_MemDescType:$src, Optional<TTG_AsyncToken> :$token);

let builders = [
OpBuilder<(ins "Type":$retType, "Value":$src),
[{
build($_builder, $_state, retType, src, /*token=*/static_cast<mlir::Value>(nullptr));
}]>];

// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}];

let results = (outs TT_Tensor:$result);
Expand All @@ -273,10 +274,10 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
let description = [{
Store a distributed tensor into a buffer in local memory.
}];
let arguments = (ins TT_Tensor:$src, TT_MemDescType:$dst);
let arguments = (ins TT_Tensor:$src, TTG_MemDescType:$dst);

let hasVerifier = 1;
// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{
$src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst))
}];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#ifndef TRITON_TYPE_INTERFACES
#define TRITON_TYPE_INTERFACES
#ifndef TRITON_GPU_TYPE_INTERFACES
#define TRITON_GPU_TYPE_INTERFACES

include "mlir/IR/OpBase.td"

// Interface dynamically attached to RankedTensorType and MemDescType.
def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> {
let cppNamespace = "::mlir";
def TTG_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> {
let cppNamespace = "::mlir::triton::gpu";
let methods = [
InterfaceMethod<"Returns the encoding of the tensor or memory descriptor",
"mlir::Attribute", "getEncoding", (ins)>,
Expand All @@ -17,8 +17,7 @@ def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> {
"int64_t", "getRank", (ins)>,
InterfaceMethod<"Returns the element type bit width",
"int64_t", "getElementTypeBitWidth", (ins)>,

];
}

#endif // TRITON_TYPE_INTERFACES
#endif // TRITON_GPU_TYPE_INTERFACES
Loading