Commit 6dfd22f

Merge OpenAI Triton commit e6b9efd (#4282)

This PR changes the Triton base from 26b45d8 to e6b9efd (May 16). Pass rate: 95.34%.

2 parents (0dce8de + 528e45e); merge commit 6dfd22f

File tree

18 files changed: +128 −752 lines changed

Makefile

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ test-unit: all
 	--ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
 	$(PYTEST) -s -n 8 python/test/unit/language/test_subprocess.py
 	$(PYTEST) -s -n 8 python/test/unit/test_debug.py --forked
+	$(PYTEST) -s -n 8 python/triton_kernels/tests/
 	TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
 	# Run cuda/test_flashattention.py separately to avoid out of gpu memory
 	$(PYTEST) -s python/test/unit/cuda/test_flashattention.py
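
For reference, the newly wired-in suite can also be launched from Python rather than make; a minimal sketch using pytest's own API (the -n flag assumes the pytest-xdist plugin, which the Makefile's other targets already rely on):

    import pytest

    # Mirrors the new Makefile line: run the triton_kernels tests
    # on 8 parallel workers with output capturing disabled.
    pytest.main(["-s", "-n", "8", "python/triton_kernels/tests/"])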

README.md

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ See [`python/triton/knobs.py`](python/triton/knobs.py) for the full list of conf
 - `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1.
 - `TRITON_F32_DEFAULT` sets the default input precision of `tl.dot` when using 32-bit floats, which can be either `ieee`, `tf32`, or `tf32x3`.
 - `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
-- `TRITON_STRIP_DEBUG_INFO` removes all debug information from the module, including location information
+- `TRITON_DISABLE_LINE_INFO=1` removes all line information from the module

 N.B. Some of these environment variables don't have a knob in `knobs.py`-- those are only relevant to the C++ layer(s), hence they don't exist in the python layer.
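
To make the replacement knob concrete, here is a minimal usage sketch. The kernel is a hypothetical placeholder, and it assumes the variable must be set before the kernel is first compiled:

    import os
    os.environ["TRITON_DISABLE_LINE_INFO"] = "1"  # strip line info from compiled modules

    import triton
    import triton.language as tl

    @triton.jit
    def add_one(x_ptr, n, BLOCK: tl.constexpr):  # hypothetical example kernel
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)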

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
@@ -322,7 +322,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at

   // ---- begin Ampere & Hopper ----
   if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
-    int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
+    int perPhase = 128 / (std::max<int>(1, shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()));
     perPhase = std::max<int>(perPhase, 1);
     std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
     int vecWidth = 32 / typeWidthInBit;
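
The added std::max guard matters because the inner integer division can round down to zero; a small worked model in Python (the values are illustrative, not taken from the commit):

    def per_phase(shape_contig: int, k_width: int) -> int:
        # The old expression divided 128 by (shape_contig * 4 // k_width)
        # directly; when shape_contig * 4 < k_width, integer division
        # truncates that denominator to 0 and the division faults.
        denom = max(1, shape_contig * 4 // k_width)  # the clamp added here
        return max(1, 128 // denom)

    # e.g. shape_contig=1, k_width=8: the old denominator is 1*4 // 8 == 0,
    # a division by zero; with the clamp, per_phase(1, 8) == 128.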

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 22 additions & 27 deletions
@@ -1539,40 +1539,35 @@ std::optional<LinearLayout>
 chooseMfmaLikeStoreLayout(RankedTensorType valType) {
   auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());

-  // Currently support transposed [B]F16 MFMA32x32 on CDNA4
+  // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
   bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
   Type elemType = valType.getElementType();
   if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
         mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
         isMfma32))
     return {};

-  MLIRContext *ctx = mfmaLayout.getContext();
-  StringAttr kRegister = S("register");
-  StringAttr kLane = S("lane");
-  StringAttr kWarp = S("warp");
-  StringAttr kBlock = S("block");
-
-  SmallVector<unsigned> order = getDefaultMmaOrder(mfmaLayout);
-  auto standardOutDims = standardOutDimNames(ctx, 2);
-  // We make each thread handle 8 consecutive elements to enable 128-bit
-  // global stores for [b]f16 types and keep the thread pattern in each lane
-  // similar to the canonical mfmaLayout.
-  LinearLayout mfma8Layout = LinearLayout::empty();
-  mfma8Layout =
-      LinearLayout({{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-                    {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
-                    {kWarp, {}},
-                    {kBlock, {}}},
-                   {standardOutDims[order[0]], standardOutDims[order[1]]});
-
-  LinearLayout warpLayout =
-      identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order);
-  LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) *
-                           warpLayout.transposeOuts(standardOutDims);
-  mfma8Layout = combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(),
-                                       valType.getShape());
-  return mfma8Layout;
+  auto valShape = valType.getShape();
+  LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
+  auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
+  StringAttr dimM = mfmaOutDims[0];
+  StringAttr dimN = mfmaOutDims[1];
+
+  auto swapLL = LinearLayout::empty();
+  // The rows are kept as-is with an identity linear layout.
+  swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
+  // In the transposed mfma32 layout, each thread holds 4 consecutive values
+  // along the N dim. We want to exchange columns 4-7 (owned by threads 32-63)
+  // and columns 8-11 (owned by threads 0-31) every 16 columns so that each
+  // thread holds 8 elements. This amounts to swapping the 2nd and 3rd basis
+  // vectors of an identity linear layout.
+  std::vector<std::vector<int32_t>> dimNBases(mfmaLL.getOutDimSizeLog2(dimN));
+  std::generate(dimNBases.begin(), dimNBases.end(),
+                [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
+  std::swap(dimNBases[2], dimNBases[3]);
+  swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});
+
+  return mfmaLL.compose(swapLL);
 }

 LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
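
The column exchange in the new code is easiest to see on a toy model. The sketch below assumes LinearLayout's XOR semantics (the output index is the XOR of the basis vectors selected by the set bits of the input index) and shows how swapping the 2nd and 3rd basis vectors permutes 16 columns:

    def apply_ll(bases, x):
        # XOR together the basis vectors picked out by the set bits of x.
        out = 0
        for i, b in enumerate(bases):
            if (x >> i) & 1:
                out ^= b
        return out

    identity = [1 << i for i in range(4)]  # 16 columns per tile, for illustration
    swapped = identity[:]
    swapped[2], swapped[3] = swapped[3], swapped[2]  # the std::swap above

    print([apply_ll(swapped, c) for c in range(16)])
    # [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]
    # Columns 4-7 and 8-11 trade places within every 16-column tile, so the
    # 4-element runs owned by each thread pair up into 8-element runs.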

python/src/passes.cc

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ void init_triton_analysis(py::module &&m) {

 void init_triton_passes_common(py::module &&m) {
   using namespace mlir;
-  ADD_PASS_WRAPPER_0("add_strip_debug_info", createStripDebugInfoPass);
   ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass);
   ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass);
   ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass);

python/test/unit/conftest.py

Lines changed: 13 additions & 0 deletions
@@ -138,3 +138,16 @@ def fresh_knobs_except_libraries(monkeypatch):
         yield fresh_function()
     finally:
         reset_function()
+
+
+@pytest.fixture
+def with_allocator():
+    import triton
+    from triton.runtime._allocation import NullAllocator
+    from triton._internal_testing import default_alloc_fn
+
+    triton.set_allocator(default_alloc_fn)
+    try:
+        yield
+    finally:
+        triton.set_allocator(NullAllocator())
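
Requesting the fixture by name is all a test needs; a hypothetical example (the test name and body are placeholders, not from this commit):

    def test_needs_scratch_allocation(with_allocator):
        # pytest installs default_alloc_fn before this body runs and
        # restores NullAllocator afterwards, even if the test fails.
        ...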
