Skip to content

Commit a1fe66b

Browse files
authored
Merge OpenAI commit 1b27b93 (#5202)
This PR changes the Triton base from e15cb57 to 1b27b93 (Sep 22). Pass rate: 97.22%
2 parents 539ac90 + d9ad3c1 commit a1fe66b

File tree

30 files changed

+468
-318
lines changed

30 files changed

+468
-318
lines changed

.github/workflows/wheels.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ jobs:
8888
export CIBW_MANYLINUX_AARCH64_IMAGE="quay.io/pypa/manylinux_2_28_${{ matrix.config.arch }}:latest"
8989
fi
9090
91-
export CIBW_BUILD="cp3{9,10,11,12,13,13t,14,14t}-manylinux_${{ matrix.config.arch }}"
92-
export CIBW_SKIP="cp{35,36,37,38}-*"
91+
export CIBW_BUILD="cp3{10,11,12,13,13t,14,14t}-manylinux_${{ matrix.config.arch }}"
92+
export CIBW_SKIP="cp{35,36,37,38,39}-*"
9393
export CIBW_ENABLE=cpython-freethreading
9494
python3 -m cibuildwheel . --output-dir wheelhouse
9595

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ You can install the latest stable release of Triton from pip:
3838
pip install triton
3939
```
4040

41-
Binary wheels are available for CPython 3.10-3.13.
41+
Binary wheels are available for CPython 3.10-3.14.
4242

4343
# Install from source
4444

@@ -262,6 +262,15 @@ export TRITON_OVERRIDE_DIR=<override_dir>
262262
# Step 4: Run the kernel again to see the overridden result
263263
```
264264

265+
**Compiler Pipeline Inspection Steps**
266+
To introspect the pipeline `add_stages`, before running your kernels, simply set
267+
the add_stages_inspection_hook like so:
268+
269+
```python
270+
def inspect_stages(_self, stages, options, language, capability):
271+
# inspect or modify add_stages here
272+
triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
273+
```
265274

266275
# Changelog
267276

docs/getting-started/installation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ You can install the latest stable release of Triton from pip:
1414
1515
pip install triton
1616
17-
Binary wheels are available for CPython 3.10-3.13.
17+
Binary wheels are available for CPython 3.10-3.14.
1818

1919
-----------
2020
From Source

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
44
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
55
#include "triton/Dialect/TritonGPU/IR/Types.h"
6+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
67
#include "triton/Tools/LayoutUtils.h"
78

89
using namespace mlir;

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,7 +1670,7 @@ void SharedLinearEncodingAttr::print(AsmPrinter &printer) const {
16701670
layout.sublayout({kOffset}, llvm::to_vector(layout.getOutDimNames()));
16711671
}
16721672
printLinearLayout(printer, layout);
1673-
printer << "}, alignment = " << getAlignment() << "}>";
1673+
printer << "}, alignment = " << getAlignment() << ">";
16741674
}
16751675

16761676
Attribute SharedLinearEncodingAttr::parse(AsmParser &parser, Type type) {
@@ -2701,19 +2701,17 @@ struct TritonGPUInferLayoutInterface
27012701
}
27022702

27032703
if (auto enc = dyn_cast<NVMMASharedEncodingAttr>(operandEncoding)) {
2704-
if (failed(checkRank(enc.getRank())))
2705-
return failure();
2706-
if (order != ArrayRef<int32_t>({1, 0})) {
2707-
return emitOptionalError(
2708-
loc, "NVMMSharedEncoding can only be transposed in 2D");
2709-
}
2704+
if (order == ArrayRef<int32_t>({1, 0})) {
2705+
if (failed(checkRank(enc.getRank())))
2706+
return failure();
27102707

2711-
CTALayoutAttr ctaLayout =
2712-
permuteCTALayout(ctx, enc.getCTALayout(), order);
2713-
resultEncoding = NVMMASharedEncodingAttr::get(
2714-
ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(),
2715-
enc.getElementBitWidth(), enc.getFp4Padded(), ctaLayout);
2716-
return success();
2708+
CTALayoutAttr ctaLayout =
2709+
permuteCTALayout(ctx, enc.getCTALayout(), order);
2710+
resultEncoding = NVMMASharedEncodingAttr::get(
2711+
ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(),
2712+
enc.getElementBitWidth(), enc.getFp4Padded(), ctaLayout);
2713+
return success();
2714+
}
27172715
}
27182716

27192717
if (auto enc = dyn_cast<BlockedEncodingAttr>(operandEncoding)) {
@@ -2729,20 +2727,25 @@ struct TritonGPUInferLayoutInterface
27292727
applyPermutation(invOrderUnsigned, enc.getOrder()), ctaLayout);
27302728
return success();
27312729
}
2730+
// Generic case
2731+
auto padded = dyn_cast<PaddedSharedEncodingAttr>(operandEncoding);
27322732

2733-
if (auto enc = dyn_cast<PaddedSharedEncodingAttr>(operandEncoding)) {
2734-
if (failed(checkRank(enc.getRank())))
2735-
return failure();
2736-
const auto &transLL =
2737-
transposeLinearLayout(enc.getLinearComponent(), order);
2738-
resultEncoding = PaddedSharedEncodingAttr::get(
2739-
ctx, enc.getIntervals(), enc.getPaddings(), transLL);
2740-
return success();
2741-
}
2742-
2743-
auto ll = toLinearLayout(shape, operandEncoding);
2733+
auto ll = padded ? padded.getLinearComponent()
2734+
: toLinearLayout(shape, operandEncoding);
2735+
if (failed(checkRank(ll.getNumOutDims())))
2736+
return failure();
27442737
auto transposedLl = transposeLinearLayout(ll, order);
2745-
resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl));
2738+
if (isa<DistributedEncodingTrait>(operandEncoding)) {
2739+
resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl));
2740+
} else if (padded) {
2741+
resultEncoding = PaddedSharedEncodingAttr::get(ctx, padded.getIntervals(),
2742+
padded.getPaddings(),
2743+
std::move(transposedLl));
2744+
} else {
2745+
auto shared = cast<SharedEncodingTrait>(operandEncoding);
2746+
resultEncoding = SharedLinearEncodingAttr::get(
2747+
ctx, std::move(transposedLl), shared.getAlignment());
2748+
}
27462749
return success();
27472750
}
27482751

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -531,40 +531,44 @@ static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
531531
Attribute srcEnc,
532532
ArrayRef<int64_t> dstShape,
533533
Attribute &dstEnc) {
534+
// TODO Delete this once SharedLinearEncodingAttr is more widely supported.
534535
if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
535-
// TODO: supporting reshape of CTA layouts is non-trivial.
536-
if (getNumCTAs(mmaEncoding) > 1)
537-
return failure();
538-
int innerDimDst =
539-
mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
540-
int innerDimSrc =
541-
mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
542-
// For now disallow reshape of the inner dimension.
543-
if (innerDimDst != innerDimSrc)
544-
return failure();
545536
auto *ctx = srcEnc.getContext();
546-
547-
// CTALayout can be all 1's because we bailed on multi-CTA layouts above.
548-
auto CTALayout = CTALayoutAttr::get(
549-
ctx,
550-
/*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
551-
/*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
552-
/*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
553-
dstEnc = NVMMASharedEncodingAttr::get(
554-
ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
555-
mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
556-
CTALayout);
557-
// Big guns, check linear layouts are equivalent
558-
// We disallow reshaping memdesc_subslice in the verifier
559-
// so allocShape == shape
560-
auto srcLL = toLinearLayout(srcShape, srcEnc);
561-
auto dstLL = toLinearLayout(dstShape, dstEnc);
562-
if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
563-
return failure();
537+
if (getNumCTAs(mmaEncoding) == 1) {
538+
int innerDimDst =
539+
mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
540+
int innerDimSrc =
541+
mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
542+
// We can keep an NVMMAShared encoding only if the innermost dimension is
543+
// preserved. Otherwise fall back to the generic shared-linear encoding
544+
// logic below.
545+
if (innerDimDst == innerDimSrc) {
546+
auto CTALayout = CTALayoutAttr::get(
547+
ctx,
548+
/*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
549+
/*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
550+
/*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
551+
auto candidateEncoding = NVMMASharedEncodingAttr::get(
552+
ctx, mmaEncoding.getSwizzlingByteWidth(),
553+
mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(),
554+
mmaEncoding.getFp4Padded(), CTALayout);
555+
auto srcLL = toLinearLayout(srcShape, srcEnc);
556+
auto dstLL = toLinearLayout(dstShape, candidateEncoding);
557+
if (reshapeLayout(ctx, srcLL, dstShape) == dstLL) {
558+
dstEnc = candidateEncoding;
559+
return success();
560+
}
561+
}
564562
}
565-
return success();
566563
}
567-
return failure();
564+
565+
// Generic LL case
566+
auto sharedEnc = cast<SharedEncodingTrait>(srcEnc);
567+
auto *ctx = srcEnc.getContext();
568+
auto srcLL = toLinearLayout(srcShape, srcEnc);
569+
auto dstLL = reshapeLayout(ctx, srcLL, dstShape);
570+
dstEnc = SharedLinearEncodingAttr::get(ctx, dstLL, sharedEnc.getAlignment());
571+
return success();
568572
}
569573

570574
LogicalResult MemDescReshapeOp::inferReturnTypes(

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,24 @@ class UseShmemForScales
257257
if (!isTmemCopyCompatible(localLoad.getSrc().getType(), usesTMAload))
258258
return failure();
259259

260-
opOperand.assign(localLoad.getSrc());
260+
PatternRewriter::InsertionGuard guard(rewriter);
261+
rewriter.setInsertionPoint(tmemAlloc);
262+
263+
Value shared = localLoad.getSrc();
264+
265+
Value reshaped5D = rewriter.create<MemDescReshapeOp>(
266+
reshapeOp5D.getLoc(), shared, reshape5DShape);
267+
SmallVector<int32_t> transposeOrder32(transposeOrder.begin(),
268+
transposeOrder.end());
269+
Value transposed = rewriter.create<MemDescTransOp>(
270+
transOp.getLoc(), reshaped5D, transposeOrder32);
271+
SmallVector<int64_t> scale2DShapeVec(scale2DShape.begin(),
272+
scale2DShape.end());
273+
Value reshaped2D = rewriter.create<MemDescReshapeOp>(
274+
reshapeOp2D.getLoc(), transposed, scale2DShapeVec);
275+
276+
opOperand.assign(reshaped2D);
277+
rewriter.eraseOp(tmemAlloc);
261278
return success();
262279
}
263280

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -680,14 +680,20 @@ LogicalResult TMEMCopyOp::verify() {
680680
getSrc().getType().getMemorySpace()))
681681
return emitOpError("The source must be a shared memory buffer");
682682

683+
auto srcTy = cast<triton::gpu::MemDescType>(getSrc().getType());
684+
auto dstTy = cast<triton::gpu::MemDescType>(getDst().getType());
685+
if (srcTy.getShape() != dstTy.getShape())
686+
return emitOpError("source shape ")
687+
<< srcTy.getShape() << " must match destination shape "
688+
<< dstTy.getShape();
689+
683690
if (getBarrier() && !isa<triton::gpu::SharedMemorySpaceAttr>(
684691
getBarrier().getType().getMemorySpace())) {
685692
return emitOpError("The optional barrier should be a shared memory buffer");
686693
}
687694
if (!getDst().getType().getMutableMemory()) {
688695
return emitOpError("Cannot copy into an immutable alloc");
689696
}
690-
auto srcTy = cast<triton::gpu::MemDescType>(getSrc().getType());
691697
auto sharedEnc =
692698
dyn_cast<triton::gpu::SharedEncodingTrait>(srcTy.getEncoding());
693699
if (sharedEnc.getAlignment() < 16) {
@@ -700,21 +706,16 @@ LogicalResult TMEMCopyOp::verify() {
700706
if (numCTAs != 1)
701707
return emitOpError("NYI: Only one CTA is supported for now.");
702708

709+
// Fp4 we could lift if we needed
703710
auto nvmmaEnc =
704711
dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(srcTy.getEncoding());
705-
if (!nvmmaEnc) {
706-
return emitOpError("Source must have nvmma layout.");
707-
}
708-
// Fp4 we could lift if we needed
709-
if (nvmmaEnc.getTransposed() || nvmmaEnc.getFp4Padded())
712+
if (nvmmaEnc && (nvmmaEnc.getTransposed() || nvmmaEnc.getFp4Padded())) {
710713
return emitOpError("The source should not be transposed or padded");
714+
}
711715
if (isa<TensorMemoryScalesEncodingAttr>(getDst().getType().getEncoding())) {
712-
if (nvmmaEnc.getSwizzlingByteWidth() != 0) {
716+
if (nvmmaEnc && nvmmaEnc.getSwizzlingByteWidth() != 0) {
713717
return emitOpError("The source should not be swizzled for now");
714718
}
715-
if (!triton::gpu::isInnermostContiguous(srcTy, 512)) {
716-
return emitOpError("The source must be in a row-major order.");
717-
}
718719
} else {
719720
if (getSrc().getType().getShape() != getDst().getType().getShape()) {
720721
return emitOpError(
@@ -728,7 +729,7 @@ LogicalResult TMEMCopyOp::verify() {
728729
if (tmemEnc.getBlockM() != 128) {
729730
return emitOpError("Tmem layout ahouls have M=128.");
730731
}
731-
if (nvmmaEnc.getSwizzlingByteWidth() == 0) {
732+
if (nvmmaEnc && nvmmaEnc.getSwizzlingByteWidth() == 0) {
732733
return emitOpError("Source layout should be swizzled.");
733734
}
734735
// When we lift this, we should make sure we handle unpacked cleanly

python/src/gluon_ir.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ struct GluonLayouts {
9898
py::handle NVMMADistributedLayout;
9999
py::handle NVMMASharedLayout;
100100
py::handle SwizzledSharedLayout;
101+
py::handle SharedLinearLayout;
101102
py::handle AMDMFMALayout;
102103
py::handle AMDWMMALayout;
103104
py::handle PaddedSharedLayout;
@@ -119,6 +120,8 @@ struct GluonLayouts {
119120
NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
120121
SwizzledSharedLayout =
121122
py::object(layouts.attr("SwizzledSharedLayout")).release();
123+
SharedLinearLayout =
124+
py::object(layouts.attr("SharedLinearLayout")).release();
122125
AMDMFMALayout = py::object(amdLayouts.attr("AMDMFMALayout")).release();
123126
AMDWMMALayout = py::object(amdLayouts.attr("AMDWMMALayout")).release();
124127
PaddedSharedLayout =
@@ -203,6 +206,14 @@ py::object layoutToGluon(Attribute layout) {
203206
toStdVector(ctaLayout.getCTAsPerCGA()),
204207
toStdVector(ctaLayout.getCTASplitNum()),
205208
toStdVector(ctaLayout.getCTAOrder()));
209+
} else if (auto sharedLl = dyn_cast<ttg::SharedLinearEncodingAttr>(layout)) {
210+
const auto &ll = sharedLl.getLinearLayout();
211+
auto ctx = layout.getContext();
212+
auto kOffset = mlir::StringAttr::get(ctx, "offset");
213+
auto kBlock = mlir::StringAttr::get(ctx, "block");
214+
return layouts.SharedLinearLayout(
215+
toStdVector(ll.getBases().lookup(kOffset)),
216+
toStdVector(ll.getBases().lookup(kBlock)), sharedLl.getAlignment());
206217
} else if (auto autoEnc = dyn_cast<gluon::AutoEncodingAttr>(layout)) {
207218
return layouts.AutoLayout();
208219
} else if (auto amdMfma = dyn_cast<ttg::AMDMfmaEncodingAttr>(layout)) {
@@ -410,14 +421,13 @@ void init_gluon_ir(py::module &&m) {
410421
.def("get_shared_linear_layout",
411422
[](GluonOpBuilder &self, std::vector<std::vector<int>> &offsetBases,
412423
std::vector<std::vector<int>> &blockBases,
413-
std::vector<int64_t> &shape, unsigned alignment) -> Attribute {
424+
unsigned alignment) -> Attribute {
414425
auto ctx = self.getContext();
415426
auto kOffset = mlir::StringAttr::get(ctx, "offset");
416427
auto kBlock = mlir::StringAttr::get(ctx, "block");
428+
auto outDims = tt::standardOutDimNames(ctx, offsetBases[0].size());
417429
auto ll = tt::LinearLayout(
418-
{{kOffset, offsetBases}, {kBlock, blockBases}},
419-
tt::standardOutDimPairs(ctx, shape),
420-
/*requireSurjective=*/true);
430+
{{kOffset, offsetBases}, {kBlock, blockBases}}, outDims);
421431
return self.getChecked<ttg::SharedLinearEncodingAttr>(ctx, ll,
422432
alignment);
423433
})
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import triton
2+
from triton import knobs
3+
4+
import os
5+
import pathlib
6+
7+
8+
def test_inspection(monkeypatch, tmp_path: pathlib.Path):
9+
stage_name = 'make_ttgir'
10+
curr_repro_path = tmp_path / ("repro_prefix." + stage_name + ".repro.mlir")
11+
repro_path = tmp_path / "repro_prefix"
12+
13+
monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1")
14+
monkeypatch.setenv("TRITON_REPRODUCER_PATH", str(repro_path))
15+
16+
inspect_stages_hook_called = False
17+
make_ttgir_wrapper_called = False
18+
19+
def inspect_stages_hook(self, stages, options, language, capability):
20+
nonlocal inspect_stages_hook_called
21+
inspect_stages_hook_called = True
22+
23+
def make_ttgir_wrapper(src, metadata, options, capability):
24+
nonlocal make_ttgir_wrapper_called
25+
make_ttgir_wrapper_called = True
26+
return self.make_ttgir(src, metadata, options, capability)
27+
28+
stages["ttgir"] = lambda src, metadata: make_ttgir_wrapper(src, metadata, options, capability)
29+
30+
@triton.jit
31+
def k1():
32+
return
33+
34+
@triton.jit
35+
def k2():
36+
return
37+
38+
# Run once to get the clean/golden repro dump
39+
k1[(1, )]()
40+
assert not inspect_stages_hook_called and not make_ttgir_wrapper_called
41+
assert os.path.exists(curr_repro_path)
42+
golden_repro = curr_repro_path.read_text()
43+
curr_repro_path.unlink()
44+
45+
# Setup hook and call again, check if hooks got called
46+
knobs.runtime.add_stages_inspection_hook = inspect_stages_hook
47+
k2[(1, )]()
48+
assert inspect_stages_hook_called and make_ttgir_wrapper_called
49+
assert os.path.exists(curr_repro_path)
50+
hook_repro = curr_repro_path.read_text()
51+
52+
# Check that repros match
53+
assert golden_repro.replace('k1', 'dummy') == hook_repro.replace('k2', 'dummy')

0 commit comments

Comments
 (0)