Skip to content

Commit e5545af

Browse files
Merge commit '4ae95e70cd81eb62f89ec530605440b85e799dee'
2 parents 5bbce9e + 4ae95e7 commit e5545af

File tree

84 files changed

+1275
-1156
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+1275
-1156
lines changed

include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
1818
const TargetInfoBase &targetInfo,
1919
const DataLayoutAnalysis *analysis = nullptr);
2020

21-
Type getElementTypeForStruct(TensorOrMemDesc type);
21+
Type getElementTypeForStruct(triton::gpu::TensorOrMemDesc type);
2222
Type convertTritonPointerType(triton::PointerType type);
2323
Type convertTritonTensorType(RankedTensorType type,
2424
const TargetInfoBase &targetInfo);
25-
Type convertMemDescType(MemDescType type, const TargetInfoBase &targetInfo);
25+
Type convertMemDescType(triton::gpu::MemDescType type,
26+
const TargetInfoBase &targetInfo);
2627
Type convertAsyncToken(triton::gpu::AsyncTokenType type);
2728
};
2829

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "triton/Dialect/Triton/IR/Utility.h"
1515
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1616
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
17+
#include "triton/Dialect/TritonGPU/IR/Types.h"
1718
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1819
#include "triton/Tools/LinearLayout.h"
1920
#include "triton/Tools/StrUtil.h"
@@ -1141,8 +1142,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
11411142
//
11421143
// Returns true on success.
11431144
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
1144-
RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
1145-
std::optional<int32_t> maxVecElems, Value shmemBase,
1145+
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
1146+
Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
11461147
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
11471148
const TargetInfoBase &target,
11481149
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
@@ -1310,13 +1311,14 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
13101311
}
13111312

13121313
SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
1313-
MemDescType srcTy, Type elemLlvmTy,
1314+
triton::gpu::MemDescType srcTy,
1315+
Type elemLlvmTy,
13141316
SharedMemoryObject smemObj,
13151317
Location loc, RewriterBase &rewriter,
13161318
const TargetInfoBase &target);
13171319

13181320
void storeDistributedToShared(
1319-
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
1321+
triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
13201322
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
13211323
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
13221324
std::pair<size_t, Type> *const llvmOpCount = nullptr);

include/triton/Dialect/Triton/IR/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
2020
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
2121
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
2222

23-
set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td)
24-
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
25-
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
26-
2723
set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td)
2824
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
2925
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
6969
static LogicalResult verifyTrait(Operation *op) {
7070
if (op->getNumOperands() < 3)
7171
return op->emitOpError("expected at least 3 operands");
72-
auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
73-
auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
74-
auto cTy = cast<TensorOrMemDesc>(op->getOperand(2).getType());
72+
auto aTy = cast<ShapedType>(op->getOperand(0).getType());
73+
auto bTy = cast<ShapedType>(op->getOperand(1).getType());
74+
auto cTy = cast<ShapedType>(op->getOperand(2).getType());
7575
auto aShape = aTy.getShape();
7676
auto bShape = bTy.getShape();
7777
auto cShape = cTy.getShape();

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
1313
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
1414
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
1515
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
16-
include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
1716
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
1817

1918

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -92,54 +92,6 @@ def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>;
9292
// Any Type in Triton IR
9393
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>;
9494

95-
// Memory descriptor type.
96-
def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
97-
let summary = "memory descriptor type (`::mlir::triton::MemDescType`) in Triton IR type system";
98-
99-
let description = [{
100-
Memory descriptor contains a base pointer (scalar) and a descriptor of the memory.
101-
If mutable memory is false that means the memory is constant and can only be allocated and stored once.
102-
A constant memory allocation is different than a tensor as it can have multiple views and the descriptor
103-
can be changed without changing the underlying memory.
104-
}];
105-
106-
let parameters = (ins
107-
ArrayRefParameter<"int64_t">:$shape,
108-
"Type":$elementType,
109-
"Attribute":$encoding,
110-
"Attribute":$memorySpace,
111-
"bool":$mutable_memory
112-
);
113-
let extraClassDeclaration = [{
114-
MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
115-
Type elementType) const {
116-
return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
117-
}
118-
119-
bool hasRank() const { return true; }
120-
}];
121-
let builders = [
122-
TypeBuilderWithInferredContext<(ins
123-
"llvm::ArrayRef<int64_t>":$shape,
124-
"Type":$elementType,
125-
"Attribute":$encoding,
126-
"Attribute":$memorySpace
127-
), [{
128-
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
129-
}]>,
130-
TypeBuilderWithInferredContext<(ins
131-
"llvm::ArrayRef<int64_t>":$shape,
132-
"Type":$elementType,
133-
"Attribute":$encoding,
134-
"Attribute":$memorySpace,
135-
"bool":$mutableMemory
136-
), [{
137-
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
138-
}]>
139-
];
140-
let hasCustomAssemblyFormat = 1;
141-
}
142-
14395
// Result type of ExperimentalMakeTensorDescriptor
14496
def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
14597
let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";

include/triton/Dialect/Triton/IR/Types.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
#define GET_TYPEDEF_CLASSES
99
#include "triton/Dialect/Triton/IR/Types.h.inc"
1010

11-
#include "triton/Dialect/Triton/IR/TypeInterfaces.h.inc"
12-
1311
namespace mlir {
1412

1513
namespace triton {

include/triton/Dialect/TritonGPU/IR/Attributes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
66

77
#define GET_ATTRDEF_CLASSES
8-
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
8+
#include "triton/Dialect/TritonGPU/IR/AttrDefs.h.inc"
99

1010
#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_

include/triton/Dialect/TritonGPU/IR/CMakeLists.txt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,15 @@ add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc)
1212
add_public_tablegen_target(TritonGPUTableGen)
1313

1414
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
15-
mlir_tablegen(TritonGPUAttrInterfaces.h.inc -gen-attr-interface-decls)
16-
mlir_tablegen(TritonGPUAttrInterfaces.cpp.inc -gen-attr-interface-defs)
17-
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
18-
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
15+
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
16+
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
17+
mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls)
18+
mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
1919
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
2020
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
2121
add_public_tablegen_target(TritonGPUAttrDefsIncGen)
22+
23+
set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
24+
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
25+
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
26+
add_public_tablegen_target(TritonGPUTypeInterfacesIncGen)

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
// TritonGPU depends on Triton
1010
#include "triton/Dialect/Triton/IR/Dialect.h"
1111
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
12-
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
1312
#include "triton/Dialect/TritonGPU/IR/Types.h"
1413

1514
#define GET_OP_CLASSES
15+
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
1616
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
1717

1818
namespace mlir {

0 commit comments

Comments
 (0)