
Commit f3a0aec

Merge OpenAI Triton commit 3c2e6f8 (#5475)
This PR changes the Triton base from 9f21c06 to 3c2e6f8 (Oct 28). Pass rate: 94.91%
2 parents a9cd5c7 + 2f659b3 commit f3a0aec

177 files changed: +4397 / -3033 lines


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -299,6 +299,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
   add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
               ${PYTHON_SRC_PATH}/ir.cc
               ${PYTHON_SRC_PATH}/gluon_ir.cc
+              ${PYTHON_SRC_PATH}/linear_layout.cc
               ${PYTHON_SRC_PATH}/passes.cc
               ${PYTHON_SRC_PATH}/interpreter.cc
               ${PYTHON_SRC_PATH}/llvm.cc

Makefile

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ test-microbenchmark: all
 test-interpret: all
 	cd python/test/unit && TRITON_INTERPRET=1 $(PYTEST) --tb=short -s -n 16 -m interpreter cuda language/test_core.py language/test_standard.py \
 	language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
-	language/test_tuple.py runtime/test_autotuner.py::test_kwargs[False] \
+	language/test_tuple.py runtime/test_launch.py runtime/test_autotuner.py::test_kwargs[False] \
 	../../tutorials/06-fused-attention.py::test_op --device=cpu

 .PHONY: test-proton

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 31 additions & 2 deletions
@@ -166,11 +166,40 @@ struct ElementwiseOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    return {rewriter.create<DestOp>(loc, elemTy, operands[0],
-                                    adaptor.getAttributes().getValue())};
+    return {DestOp::create(rewriter, loc, elemTy, operands[0],
+                           adaptor.getAttributes().getValue())};
   }
 };

+template <typename SourceOp>
+struct ElementwiseToIntrinsicOpConversion
+    : public ElementwiseOpConversionBase<
+          SourceOp, ElementwiseToIntrinsicOpConversion<SourceOp>> {
+  using Base =
+      ElementwiseOpConversionBase<SourceOp, ElementwiseToIntrinsicOpConversion>;
+  using OpAdaptor = typename Base::OpAdaptor;
+
+  using Base::Base;
+
+  explicit ElementwiseToIntrinsicOpConversion(
+      LLVMTypeConverter &typeConverter,
+      ModuleAxisInfoAnalysis &axisAnalysisPass, StringRef intrinsic,
+      PatternBenefit benefit = patternBenefitDefault)
+      : Base(typeConverter, axisAnalysisPass, benefit), intrinsic(intrinsic) {}
+
+  SmallVector<Value> createDestOps(SourceOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter,
+                                   Type elemTy, MultipleOperandsRange operands,
+                                   Location loc) const {
+    return {LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, elemTy,
+                                            operands[0])
+                .getResult(0)};
+  }
+
+private:
+  StringRef intrinsic;
+};
+
 } // namespace gpu

 } // namespace mlir::triton
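The new ElementwiseToIntrinsicOpConversion lowers an elementwise op directly to a named LLVM intrinsic call, with the intrinsic name taken as a constructor argument. A minimal registration sketch, assuming it is added from one of the existing populate* functions where patterns, typeConverter, axisInfoAnalysis, and benefit are already in scope; the source op and intrinsic name below are illustrative stand-ins, not taken from this commit:

// Sketch only: registering the new pattern for a hypothetical source op.
// math::ExpOp and the NVVM intrinsic name are illustrative, not from this diff.
patterns.add<ElementwiseToIntrinsicOpConversion<math::ExpOp>>(
    typeConverter, axisInfoAnalysis, /*intrinsic=*/"llvm.nvvm.ex2.approx.f",
    benefit);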

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 79 additions & 78 deletions
Large diffs are not rendered by default.

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 15 additions & 0 deletions
@@ -45,6 +45,11 @@ constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
 constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
 constexpr static char AttrTargetName[] = "ttg.target";
 constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
+// FIXME: rename to match above
+constexpr static char kPartitionAttrName[] = "ttg.partition";
+constexpr static char kPartitionOutputsAttrName[] = "ttg.partition.outputs";
+constexpr static char kPartitionStagesAttrName[] = "ttg.partition.stages";
+constexpr static char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag";

 // Find the contextual number of warps on which this operation is executed.
 int lookupNumWarps(Operation *op);
@@ -266,6 +271,12 @@ void dumpHWLayout(RankedTensorType tensorType);
 // Return a string representation of the layout of the tensor.
 std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

+// Return a string representation of the shared layout of the tensor.
+std::string getSharedLayoutStr(LinearLayout &ll, bool useHWPointOfView);
+
+// Return a string representation of the distributed layout of the tensor.
+std::string getDistributedLayoutStr(LinearLayout &ll, bool useHWPointOfView);
+
 template <typename T>
 llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
@@ -287,6 +298,10 @@ LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
                                   ShapedType dstTy);
 // Verify a memory allocation operation.
 LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy);
+
+std::optional<SetVector<int>> getPartitionIds(Operation *op);
+std::optional<int> getNumOutputPartitionIds(Operation *op);
+std::optional<SetVector<int>> getOutputPartitionIds(Operation *op, int idx);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
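The partition attribute names now sit beside the other ttg.* attribute constants, and getPartitionIds and friends are declared at dialect scope. A minimal usage sketch, assuming an Operation *op carrying the ttg.partition attribute; the helper implementations are presumably defined elsewhere among the 177 changed files, not in this header:

// Sketch only: query an op's partition ids through the newly declared helper.
// Assumes `op` is an Operation * inside a warp-specialized region.
if (std::optional<SetVector<int>> ids = getPartitionIds(op)) {
  for (int id : *ids)
    llvm::dbgs() << "op assigned to partition " << id << "\n";
}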

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 4 additions & 7 deletions
@@ -16,11 +16,6 @@ class ForOp;
 } // namespace scf
 } // namespace mlir

-static constexpr char kPartitionAttrName[] = "ttg.partition";
-static constexpr char kPartitionOutputsAttrName[] = "ttg.partition.outputs";
-static constexpr char kPartitionStagesAttrName[] = "ttg.partition.stages";
-static constexpr char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag";
-
 //===----------------------------------------------------------------------===//
 // PartitionSet
 //===----------------------------------------------------------------------===//
@@ -40,6 +35,7 @@ class Partition {
   ArrayRef<Operation *> getOps() const { return ops; }
   void addOp(Operation *op) { ops.push_back(op); }
   bool hasOp(Operation *op) const;
+  bool empty() const { return ops.empty(); }

   // Iterate the inputs of the partition. Input values are those that originate
   // from a different partition or a previous iteration of the current
@@ -127,8 +123,9 @@ void setPartition(Operation *op, const SetVector<Partition *> &partitions);
 // which does not work with Partition instances and iterate* functions, since
 // it does not keep the op attributes and the op list of a partition in sync.
 void setPartition(Operation *op, const SetVector<int> &partitionIds);
-
-std::optional<SetVector<int>> getPartitionIds(Operation *op);
+void setPartitionOutputs(Operation *op,
+                         ArrayRef<SetVector<int>> partitionOutputsIds);
+SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);

 } // namespace mlir::triton::gpu
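The per-op getPartitionIds declaration moves to Dialect.h (see above), while this header gains a per-result interface. A minimal sketch of the new functions, assuming an op whose first result is consumed by partition 0 (illustrative values only; the attribute encoding is handled by the helpers):

// Sketch only: record which partitions consume each result, then read it back.
SmallVector<SetVector<int>, 4> outputs(op->getNumResults());
outputs[0].insert(0); // result #0 feeds partition 0
setPartitionOutputs(op, outputs);
SmallVector<SetVector<int>, 4> consumers = getPartitionOutputs(op);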

include/triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ template <typename OpT, typename... Args>
 OpT createInto(OpBuilder &b, Location loc,
                std::optional<SetVector<int>> partitionSet,
                StageCluster stageCluster, Args &&...args) {
-  auto op = b.create<OpT>(loc, std::forward<Args>(args)...);
+  auto op = OpT::create(b, loc, std::forward<Args>(args)...);
   if (partitionSet) {
     setPartition(op, *partitionSet);
     setStageCluster(b, op, stageCluster);
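This is the same OpBuilder-to-static-create migration applied in ElementwiseOpToLLVMBase.h and Membar.cpp in this commit; the caller-facing createInto interface is unchanged. A usage sketch, assuming an arith::AddIOp payload with lhs and rhs values already in scope (illustrative, not from this diff):

// Sketch only: build an op into a partition set and stage cluster as before.
auto add = createInto<arith::AddIOp>(b, loc, partitionIds, stageCluster,
                                     lhs, rhs);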

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 0 additions & 7 deletions
@@ -184,13 +184,6 @@ getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,

 // Clean up attributes passing over schedules across stages in pipelining
 void removePipeliningAttributes(ModuleOp moduleOp);
-
-// For LoadOp, DescriptorLoad, and DescriptorGather ops, determine if
-// they should be pipelined.
-bool isPipeliningBeneficial(Operation *op,
-                            triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
-                            bool filterSmall = true);
-
 } // namespace triton
 } // namespace mlir
include/triton/Tools/LinearLayout.h

Lines changed: 2 additions & 0 deletions
@@ -869,6 +869,8 @@ inline std::ostream &operator<<(std::ostream &os, const ColumnAction &action) {
   return os;
 }

+std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout);
+
 } // namespace mlir::triton

 #endif // TRITON_TOOLS_LINEARLAYOUT_H
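getMatrix is now exposed in the public header, presumably so the new python/src/linear_layout.cc source added in CMakeLists.txt can reach it (an assumption; that binding code is not shown in this view). A trivial call sketch, assuming ll is an existing LinearLayout:

// Sketch only: obtain the bit-packed matrix backing a linear layout
// (exact packing is defined by the implementation, not shown here).
std::unique_ptr<uint64_t[]> mat = getMatrix(ll);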

lib/Analysis/Membar.cpp

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ void MembarOrFenceAnalysis::visitTerminator(

 void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) {
   OpBuilder::InsertionGuard g(*builder);
-  auto barrierOp = builder->create<triton::gpu::LocalBarrierOp>(op->getLoc());
+  auto barrierOp = triton::gpu::LocalBarrierOp::create(*builder, op->getLoc());
 }

 void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
