Commit c056d82

Merge OpenAI Triton commit 815b2a4 (#4790)
This PR changes the Triton base from d183197 to 815b2a4 (Jul 17). Pass rate: 98.62%.
2 parents 12fdf3f + 58bdfaf commit c056d82

60 files changed: +2075, -1196 lines


.github/workflows/integration-tests-amd.yml

Lines changed: 6 additions & 3 deletions
@@ -31,7 +31,11 @@ jobs:
       CCACHE_COMPRESS: "true"
     container:
       image: ${{ matrix.image }}
-      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
+      # Cache save/restore is on the host machine at directory /home/runner/.triton, while in the docker
+      # container expect it at /github/home/.triton. So map here to make sure visible in docker.
+      options: >-
+        --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
+        --volume /home/runner/.triton:/github/home/.triton
     steps:
     - name: Checkout
       uses: actions/checkout@v4
@@ -54,7 +58,6 @@ jobs:
         echo "llvm=$(cat $llvm_file | cut -c 1-8)" >> $GITHUB_OUTPUT
         echo "nvidia=$(sha256sum $nvidia_file | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
         echo "json=$(cat $json_file)" >> $GITHUB_OUTPUT
-        echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT
       shell: bash
     - name: Cache build dependencies
       uses: actions/cache@v4
@@ -162,5 +165,5 @@ jobs:
       # Always cleanup the worker, even if builds or tests failed
       if: always()
       run: |
-        rm -rf ~/.triton
+        rm -rf ~/.triton/cache
         rm -rf ~/.ccache

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 26 additions & 28 deletions
@@ -352,6 +352,21 @@ class SharedMemoryObject {
   SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
                                 RewriterBase &rewriter) const;

+  // Returns a mask representing all the bits of the memdesc offsets that
+  // may be modified by an affine offset coming from a memdesc_subview.
+  // The offsets are considered to be in the type of the memdesc.
+  // For padded layouts, we return the offsets without padding.
+  static uint64_t getMaskSpanOffsets(triton::gpu::MemDescType srcTy);
+
+  // Returns whether the shared memory access had a memdesc_subview
+  // that is rank-preserving (soon to be called memdesc_slice)
+  static bool isAffineSharedMemoryAccess(triton::gpu::MemDescType srcTy) {
+    return getMaskSpanOffsets(srcTy) != 0;
+  }
+
+  Value getShmemOffset(Location loc, RewriterBase &rewriter,
+                       triton::gpu::MemDescType srcTy) const;
+
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
   Value getCSwizzleOffset(int dim) const {
     assert(dim >= 0 && dim < offsets.size());
@@ -462,7 +477,6 @@ std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc);
 // -----------------------------------------------------------------------
 using LLVM::SharedMemoryObject;
 using ::mlir::LLVM::delinearize;
-using ::mlir::LLVM::SharedMemoryObject;
 using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
 using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
@@ -474,24 +488,6 @@ using ::mlir::triton::gpu::SliceEncodingAttr;
 Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
           ArrayRef<Value> strides);

-/// Extend 2d shared object to 3d.
-///
-/// If tensor has 3 dimensions, returns original shared object.
-/// If tensor shape is [M, N], return shared object describing shape [1, M, N]
-///
-/// This Function is used to simplify processing of 2d and 3d dot operands,
-/// particularly in the conversion of local_load operation.
-///
-/// \param rewriter
-/// \param loc
-/// \param smemObj
-/// \param shape shape of a tensor represented by smemObj
-/// \returns shared object describing 3d tensor
-SharedMemoryObject
-getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
-                              SharedMemoryObject smemObj,
-                              ArrayRef<int64_t> shape);
-
 // "Applies" the given layout by computing layout(indices) and returning the
 // resulting Values.
 //
@@ -568,7 +564,8 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
 SmallVector<Value>
 lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
                 ArrayRef<Value> valsArray, // Input for store, output for load
-                Type llvmElemTy, Value smemBase,
+                Type llvmElemTy, Value smemBase, Value affineOffset,
+                uint64_t maskSpanAffineOffset,
                 ConversionPatternRewriter &rewriter,
                 const TargetInfoBase &targetInfo);

@@ -578,20 +575,21 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
 SmallVector<Value> lowerLdSt(
     Location loc, MLIRContext *ctx, LinearLayout cvt,
     ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase, ConversionPatternRewriter &rewriter,
+    Type llvmElemTy, Value smemBase, Value affineOffset,
+    uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
     const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
     std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
                                      ArrayRef<Value>, Value, int, VectorType)>
         lowerInst);

 // Lower local_load/local_store via ld.shared/st.shared
-SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
-                                  // Map from registers to offset
-                                  LinearLayout cvt, ArrayRef<Value> valsArray,
-                                  // Input for store, output for load
-                                  Type llvmElemTy, Value smemBase,
-                                  ConversionPatternRewriter &rewriter,
-                                  const TargetInfoBase &targetInfo);
+SmallVector<Value>
+lowerLocalLdSt(Location loc, MLIRContext *ctx,
+               LinearLayout cvt,          // Map from registers to offset
+               ArrayRef<Value> valsArray, // Input for store, empty for load
+               Type llvmElemTy, triton::gpu::MemDescType srcTy,
+               SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
+               const TargetInfoBase &targetInfo);

 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter);
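
The mask-span pair above is the crux of the affine-offset plumbing: an access is treated as affine exactly when some linearized offset bit can still be perturbed by a memdesc_subview start offset. The standalone C++ sketch below illustrates only that idea; it is not the Triton implementation, and the row-major linearization, power-of-two dims, tile-aligned start offsets, and both helper names are assumptions made for the example:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical analogue of getMaskSpanOffsets: walk dimensions minor-to-major
// and mark the linearized offset bits that a tile-aligned subview start
// offset could modify.
uint64_t maskSpanOffsets(const std::vector<int64_t> &allocShape,
                         const std::vector<int64_t> &tileShape) {
  assert(allocShape.size() == tileShape.size());
  uint64_t mask = 0;
  int bit = 0;
  for (int d = (int)allocShape.size() - 1; d >= 0; --d)
    for (int64_t stride = 1; stride < allocShape[d]; stride *= 2, ++bit)
      if (stride >= tileShape[d]) // bits above the tile extent may vary
        mask |= uint64_t(1) << bit;
  return mask;
}

// Hypothetical analogue of isAffineSharedMemoryAccess: any nonzero bit means
// a subview start offset can affect the shared memory address.
bool isAffineAccess(const std::vector<int64_t> &allocShape,
                    const std::vector<int64_t> &tileShape) {
  return maskSpanOffsets(allocShape, tileShape) != 0;
}

int main() {
  // A [4, 8] tile inside a [16, 32] allocation: the start offset can toggle
  // the two high column bits and the two high row bits, so the mask is nonzero.
  assert(isAffineAccess({16, 32}, {4, 8}));
  // A full-size view: no offset bit can change, so the mask is 0.
  assert(!isAffineAccess({16, 32}, {16, 32}));
}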

include/triton/Dialect/TritonGPU/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 #define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_

 #include "mlir/Pass/Pass.h"
+#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

 namespace mlir {

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 6 additions & 2 deletions
@@ -106,7 +106,8 @@ def TritonGPUAutomaticWarpSpecialization : Pass<"tritongpu-automatic-warp-specia
     "mlir::triton::gpu::TritonGPUDialect",
     "mlir::scf::SCFDialect",
     "mlir::arith::ArithDialect",
-    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
+    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
+    "triton::nvws::NVWSDialect"
   ];

   let options = [
@@ -143,7 +144,10 @@ def TritonGPUPartitionLoops : Pass<"tritongpu-partition-loops", "mlir::ModuleOp"
     between any of the partitions.
   }];

-  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+  let dependentDialects = [
+    "mlir::triton::gpu::TritonGPUDialect",
+    "triton::nvws::NVWSDialect"
+  ];
 }

 def TritonGPUOptimizePartitionWarps : Pass<"tritongpu-optimize-partition-warps", "mlir::ModuleOp"> {
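
Registering NVWSDialect as a dependent dialect matters because MLIR only lets a pass create ops from dialects that were loaded before the pipeline ran. As a rough, hedged sketch of the C++ that a TableGen dependentDialects list expands to; upstream MLIR dialects stand in here for the Triton and NVWS ones, whose headers live in the Triton tree:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Pass/Pass.h"

// Sketch of what `let dependentDialects = [...]` generates for a module pass.
struct PartitionLoopsSketch
    : mlir::PassWrapper<PartitionLoopsSketch,
                        mlir::OperationPass<mlir::ModuleOp>> {
  // Pre-load every dialect whose ops this pass may create; building an op
  // from an unloaded dialect fails at runtime.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::arith::ArithDialect, mlir::scf::SCFDialect>();
  }
  void runOnOperation() override { /* partition loops here */ }
};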

include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
+#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+namespace mlir::triton::nvidia_gpu {
+
+LogicalResult verifyBarrierType(Operation *op,
+                                mlir::triton::gpu::MemDescType barrierType);
+
+}
+
+#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
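
The new header gives barrier-consuming ops one shared operand check. The committed implementation lives in a .cpp file outside this excerpt; purely as a hedged sketch, under the assumption that an mbarrier is modeled as a single i64 slot in shared memory, such a verifier might look like:

// Hedged sketch only; the constraint (exactly one i64 element) is an
// assumption, not the committed implementation.
#include "mlir/IR/Operation.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace mlir::triton::nvidia_gpu {

LogicalResult verifyBarrierTypeSketch(Operation *op,
                                      triton::gpu::MemDescType barrierTy) {
  int64_t numel = 1;
  for (int64_t d : barrierTy.getShape())
    numel *= d; // an mbarrier occupies a single shared-memory slot
  if (!barrierTy.getElementType().isInteger(64) || numel != 1)
    return op->emitOpError("barrier must be a memdesc of one i64 element");
  return success();
}

} // namespace mlir::triton::nvidia_gpu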

include/triton/Tools/LayoutUtils.h

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
 ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);

 std::pair<int64_t, ColumnAction>
-actionAdditiveStrides(const LinearLayout &layout);
+actionAdditiveStrides(const LinearLayout &layout, uint64_t maskSpanOffsets);

 // For a layout A with A.hasInDim(kReg), repeat the values so that they have
 // the same broadcasting as layout

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 5 additions & 2 deletions
@@ -202,6 +202,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

     assert(permutedInVals.size() == tileSize * nReps);
     SmallVector<Value> outVals;
+    auto affineOffset = b.i32_val(0);
+    auto maskSpanAffineOffset = 0;
     for (int i = 0; i < nReps; ++i) {
       if (i > 0)
         b.barrier();
@@ -210,11 +212,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
           ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
       // Store
       lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
-                      rewriter, targetInfo);
+                      affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
       b.barrier();
       // Load
       SmallVector<Value> tileOutVals = lowerLdStShared(
-          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, rewriter, targetInfo);
+          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, affineOffset,
+          maskSpanAffineOffset, rewriter, targetInfo);
       llvm::append_range(outVals, tileOutVals);
     }

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 7 additions & 8 deletions
@@ -53,8 +53,8 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
   cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
-  lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, smemObj.getBase(), rewriter,
-                 targetInfo);
+  lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
+                 rewriter, targetInfo);

   return success();
 }
@@ -177,10 +177,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto regTy = cast<RankedTensorType>(regVal.getType());
     auto typeConverter = getTypeConverter();

-    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
-        loc, adaptor.getSrc(),
-        typeConverter->convertType(memDescTy.getElementType()), rewriter);
-    auto llvmElemTy = typeConverter->convertType(regTy.getElementType());
+    auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
+    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
+                                                         llvmElemTy, rewriter);

     // See [Legacy local_load/local_store]
     if (!targetInfo.isCuda()) {
@@ -206,8 +205,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto kOffset = str_attr("offset");
     cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});

-    auto outVals = lowerLocalLdSt(op.getLoc(), ctx, cvt, {}, llvmElemTy,
-                                  smemObj.getBase(), rewriter, targetInfo);
+    auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
+                                  smemObj, rewriter, targetInfo);

     Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);
