Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ include(ExternalProject)

set(CMAKE_INCLUDE_CURRENT_DIR ON)

project(triton CXX)
project(triton CXX C)
include(CTest)

list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
Expand Down
5 changes: 5 additions & 0 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);

// Check if MFMA layout can be converted to the dot operand
// layout using warp shuffle.
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
RankedTensorType dstTy);

// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
Expand Down
5 changes: 3 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
const TargetInfoBase &targetInfo,
const DataLayoutAnalysis *analysis = nullptr);

Type getElementTypeForStruct(TensorOrMemDesc type);
Type getElementTypeForStruct(triton::gpu::TensorOrMemDesc type);
Type convertTritonPointerType(triton::PointerType type);
Type convertTritonTensorType(RankedTensorType type,
const TargetInfoBase &targetInfo);
Type convertMemDescType(MemDescType type, const TargetInfoBase &targetInfo);
Type convertMemDescType(triton::gpu::MemDescType type,
const TargetInfoBase &targetInfo);
Type convertAsyncToken(triton::gpu::AsyncTokenType type);
};

Expand Down
10 changes: 6 additions & 4 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
Expand Down Expand Up @@ -1141,8 +1142,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
//
// Returns true on success.
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, Value shmemBase,
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
Expand Down Expand Up @@ -1310,13 +1311,14 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
}

SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
MemDescType srcTy, Type elemLlvmTy,
triton::gpu::MemDescType srcTy,
Type elemLlvmTy,
SharedMemoryObject smemObj,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

void storeDistributedToShared(
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);
Expand Down
4 changes: 0 additions & 4 deletions include/triton/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)

set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td)
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)

set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td)
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
Expand Down
6 changes: 3 additions & 3 deletions include/triton/Dialect/Triton/IR/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
static LogicalResult verifyTrait(Operation *op) {
if (op->getNumOperands() < 3)
return op->emitOpError("expected at least 3 operands");
auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
auto cTy = cast<TensorOrMemDesc>(op->getOperand(2).getType());
auto aTy = cast<ShapedType>(op->getOperand(0).getType());
auto bTy = cast<ShapedType>(op->getOperand(1).getType());
auto cTy = cast<ShapedType>(op->getOperand(2).getType());
auto aShape = aTy.getShape();
auto bShape = bTy.getShape();
auto cShape = cTy.getShape();
Expand Down
1 change: 0 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"


Expand Down
48 changes: 0 additions & 48 deletions include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -92,54 +92,6 @@ def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>;
// Any Type in Triton IR
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>;

// Memory descriptor type.
def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
let summary = "memory descriptor type (`::mlir::triton::MemDescType`) in Triton IR type system";

let description = [{
Memory descriptor contains a base pointer (scalar) and a descriptor of the memory.
If mutable memory is false that means the memory is constant and can only be allocated and stored once.
A constant memory allocation is different than a tensor as it can have multiple views and the descriptor
can be changed without changing the underlying memory.
}];

let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutable_memory
);
let extraClassDeclaration = [{
MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
}

bool hasRank() const { return true; }
}];
let builders = [
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
}]>,
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutableMemory
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
}]>
];
let hasCustomAssemblyFormat = 1;
}

// Result type of ExperimentalMakeTensorDescriptor
def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";
Expand Down
2 changes: 0 additions & 2 deletions include/triton/Dialect/Triton/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.h.inc"

#include "triton/Dialect/Triton/IR/TypeInterfaces.h.inc"

namespace mlir {

namespace triton {
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/Attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"

#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#include "triton/Dialect/TritonGPU/IR/AttrDefs.h.inc"

#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_
13 changes: 9 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonGPUTableGen)

set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(TritonGPUAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(TritonGPUAttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(TritonGPUTypeInterfacesIncGen)
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Types.h"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

namespace mlir {
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#define TRITONGPU_ATTRDEFS

include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"

//===----------------------------------------------------------------------===//
// TritonGPU Attribute Definitions
Expand Down
5 changes: 4 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef TRITON_GPU_DIALECT_INTERFACES_H
#define TRITON_GPU_DIALECT_INTERFACES_H

// clang-format off
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc"
#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc"
// clang-format on

#endif // TRITON_GPU_DIALECT_INTERFACES_H
29 changes: 15 additions & 14 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
Expand Down Expand Up @@ -95,7 +96,7 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [

let arguments = (
ins TT_PtrTensor:$src,
TT_MemDescType:$result,
TTG_MemDescType:$result,
Optional<I1Tensor>:$mask,
Optional<TT_Type>:$other,
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
Expand Down Expand Up @@ -168,7 +169,7 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf
}];
let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}];

let results = (outs TT_MemDescType:$result);
let results = (outs TTG_MemDescType:$result);
let hasFolder = 1;
let hasVerifier = 1;
}
Expand All @@ -191,9 +192,9 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedM
operand.
}];

let arguments = (ins TT_MemDescType:$src);
let arguments = (ins TTG_MemDescType:$src);

// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}];
}

Expand All @@ -212,12 +213,12 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
Then in Python syntax, the subview covers input[1][0:4][4:8].
}];
let arguments = (
ins TT_MemDescType:$src, Variadic<I32>:$offsets);
ins TTG_MemDescType:$src, Variadic<I32>:$offsets);

// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];

let results = (outs TT_MemDescType:$result);
let results = (outs TTG_MemDescType:$result);

let hasVerifier = 1;
}
Expand All @@ -233,14 +234,14 @@ def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
representing a transposed view of the buffer.
}];

let arguments = (ins TT_MemDescType:$src, Variadic<I32>:$order);
let arguments = (ins TTG_MemDescType:$src, Variadic<I32>:$order);

let arguments = (
ins TT_MemDescType:$src,
ins TTG_MemDescType:$src,
DenseI32ArrayAttr:$order
);

let results = (outs TT_MemDescType:$result);
let results = (outs TTG_MemDescType:$result);

let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";

Expand All @@ -253,15 +254,15 @@ def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffe
let description = [{
Load a tensor from the local memory descriptor into a distributed tensor.
}];
let arguments = (ins TT_MemDescType:$src, Optional<TTG_AsyncToken> :$token);
let arguments = (ins TTG_MemDescType:$src, Optional<TTG_AsyncToken> :$token);

let builders = [
OpBuilder<(ins "Type":$retType, "Value":$src),
[{
build($_builder, $_state, retType, src, /*token=*/static_cast<mlir::Value>(nullptr));
}]>];

// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}];

let results = (outs TT_Tensor:$result);
Expand All @@ -273,10 +274,10 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
let description = [{
Store a distributed tensor into a buffer in local memory.
}];
let arguments = (ins TT_Tensor:$src, TT_MemDescType:$dst);
let arguments = (ins TT_Tensor:$src, TTG_MemDescType:$dst);

let hasVerifier = 1;
// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
// Use qualified() otherwise "!triton_gpu.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{
$src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst))
}];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#ifndef TRITON_TYPE_INTERFACES
#define TRITON_TYPE_INTERFACES
#ifndef TRITON_GPU_TYPE_INTERFACES
#define TRITON_GPU_TYPE_INTERFACES

include "mlir/IR/OpBase.td"

// Interface dynamically attached to RankedTensorType and MemDescType.
def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> {
let cppNamespace = "::mlir";
def TTG_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> {
let cppNamespace = "::mlir::triton::gpu";
let methods = [
InterfaceMethod<"Returns the encoding of the tensor or memory descriptor",
"mlir::Attribute", "getEncoding", (ins)>,
Expand All @@ -17,8 +17,7 @@ def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> {
"int64_t", "getRank", (ins)>,
InterfaceMethod<"Returns the element type bit width",
"int64_t", "getElementTypeBitWidth", (ins)>,

];
}

#endif // TRITON_TYPE_INTERFACES
#endif // TRITON_GPU_TYPE_INTERFACES
Loading