Commit 682c23a

Merge commit 'e6b9efdff5c34c946990a0bd40c4b8ed02fe71fd'
2 parents: 0dce8de + e6b9efd

20 files changed (+232, -1467 lines)

Makefile

Lines changed: 5 additions & 4 deletions
@@ -30,13 +30,14 @@ test-cpp:
 
 .PHONY: test-python
 test-unit: all
-	cd python/test/unit && $(PYTEST) -s -n 8 --ignore=cuda/test_flashattention.py \
-		--ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
+	cd python/test/unit && $(PYTEST) -s -n 8 --ignore=language/test_line_info.py \
+		--ignore=language/test_subprocess.py --ignore=test_debug.py
 	$(PYTEST) -s -n 8 python/test/unit/language/test_subprocess.py
 	$(PYTEST) -s -n 8 python/test/unit/test_debug.py --forked
+	$(PYTEST) -s -n 8 python/triton_kernels/tests/
 	TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
-	# Run cuda/test_flashattention.py separately to avoid out of gpu memory
-	$(PYTEST) -s python/test/unit/cuda/test_flashattention.py
+	# Run attention separately to avoid out of gpu memory
+	$(PYTEST) -s python/tutorials/06-fused-attention.py
 	TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
 		$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
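
For reference, the updated `test-unit` selection can also be reproduced directly with pytest. A minimal sketch, assuming pytest and pytest-xdist are installed and the commands are run from the repository root (the paths come from the Makefile hunk above):

```python
# Sketch: reproduce the new parts of the `test-unit` target outside of make.
import pytest

# New in this commit: the triton_kernels test suite is part of test-unit.
pytest.main(["-s", "-n", "8", "python/triton_kernels/tests/"])

# The flash-attention unit test is gone; the fused-attention tutorial runs
# instead, still on its own to avoid running out of GPU memory.
pytest.main(["-s", "python/tutorials/06-fused-attention.py"])
```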

README.md

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ See [`python/triton/knobs.py`](python/triton/knobs.py) for the full list of conf
 - `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1.
 - `TRITON_F32_DEFAULT` sets the default input precision of `tl.dot` when using 32-bit floats, which can be either `ieee`, `tf32`, or `tf32x3`.
 - `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
-- `TRITON_STRIP_DEBUG_INFO` removes all debug information from the module, including location information
+- `TRITON_DISABLE_LINE_INFO=1` removes all line information from the module
 
 N.B. Some of these environment variables don't have a knob in `knobs.py`-- those are only relevant to the C++ layer(s), hence they don't exist in the python layer.
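
The new bullet replaces the removed `TRITON_STRIP_DEBUG_INFO` entry with the existing `TRITON_DISABLE_LINE_INFO` knob. A minimal sketch of how the variable is typically used; only the environment variable itself comes from the README, the kernel below is a hypothetical example, and the variable must be set before the kernel is compiled:

```python
# Disable line info for everything compiled after this point.
import os
os.environ["TRITON_DISABLE_LINE_INFO"] = "1"

import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Hypothetical kernel used only to show where compilation happens.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)
```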

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
@@ -322,7 +322,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
 
       // ---- begin Ampere & Hopper ----
       if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
-        int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
+        int perPhase = 128 / (std::max<int>(1, shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()));
         perPhase = std::max<int>(perPhase, 1);
         std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
         int vecWidth = 32 / typeWidthInBit;
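
The only change here wraps the inner division in `std::max<int>(1, ...)`: when `shapePerCTA[order[0]] * 4` is smaller than `dotOpEnc.getKWidth()`, the integer division rounds to zero and the old expression divided 128 by zero. A sketch of the arithmetic in Python, with illustrative shapes that are assumptions rather than values taken from the commit:

```python
def per_phase(shape_per_cta_fast_dim: int, k_width: int) -> int:
    inner = shape_per_cta_fast_dim * 4 // k_width
    # Old code: 128 // inner, which divides by zero whenever
    # shape_per_cta_fast_dim * 4 < k_width.
    inner = max(1, inner)        # the new std::max<int>(1, ...) guard
    return max(128 // inner, 1)  # the pre-existing perPhase clamp still applies


assert per_phase(64, 8) == 4     # unchanged for typical shapes
assert per_phase(4, 32) == 128   # previously a division by zero
```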

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 22 additions & 27 deletions
@@ -1539,40 +1539,35 @@ std::optional<LinearLayout>
 chooseMfmaLikeStoreLayout(RankedTensorType valType) {
   auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());
 
-  // Currently support transposed [B]F16 MFMA32x32 on CDNA4
+  // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
   bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
   Type elemType = valType.getElementType();
   if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
        mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
        isMfma32))
    return {};
 
-  MLIRContext *ctx = mfmaLayout.getContext();
-  StringAttr kRegister = S("register");
-  StringAttr kLane = S("lane");
-  StringAttr kWarp = S("warp");
-  StringAttr kBlock = S("block");
-
-  SmallVector<unsigned> order = getDefaultMmaOrder(mfmaLayout);
-  auto standardOutDims = standardOutDimNames(ctx, 2);
-  // We make each thread handle 8 consecutive elements to enable 128-bit
-  // global stores for [b]f16 types and keep the thread pattern in each lane
-  // similar to the canonical mfmaLayout.
-  LinearLayout mfma8Layout = LinearLayout::empty();
-  mfma8Layout =
-      LinearLayout({{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-                    {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
-                    {kWarp, {}},
-                    {kBlock, {}}},
-                   {standardOutDims[order[0]], standardOutDims[order[1]]});
-
-  LinearLayout warpLayout =
-      identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order);
-  LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) *
-                           warpLayout.transposeOuts(standardOutDims);
-  mfma8Layout = combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(),
-                                       valType.getShape());
-  return mfma8Layout;
+  auto valShape = valType.getShape();
+  LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
+  auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
+  StringAttr dimM = mfmaOutDims[0];
+  StringAttr dimN = mfmaOutDims[1];
+
+  auto swapLL = LinearLayout::empty();
+  // The rows are kept as is with an identity linear layout.
+  swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
+  // In transposed mfma32 layout, each thread holds 4 consecutive values along N
+  // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11
+  // (owned by thread 0-31) every 16 columns to make each thread holds 8
+  // elements. This would mean exchange the 2nd and 3rd basis vector from an
+  // identity linear layout.
+  std::vector<std::vector<int32_t>> dimNBases(mfmaLL.getOutDimSizeLog2(dimN));
+  std::generate(dimNBases.begin(), dimNBases.end(),
+                [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
+  std::swap(dimNBases[2], dimNBases[3]);
+  swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});
+
+  return mfmaLL.compose(swapLL);
 }
 
 LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
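
The rewritten version derives the store layout from the canonical MFMA linear layout and composes it with a layout that swaps the 2nd and 3rd identity basis vectors along the N dimension. Since each basis vector corresponds to one bit of the column index, the swap exchanges columns 4-7 with columns 8-11 inside every 16-column tile, which (per the comment in the new code) is what turns each thread's two groups of 4 values into 8 contiguous elements. A small Python sketch of just that column permutation; the full `LinearLayout` machinery is not modeled:

```python
def swap_columns(col: int) -> int:
    # Bits 2 and 3 of the column index correspond to the swapped basis vectors.
    bit2 = (col >> 2) & 1          # basis 4: selects columns 4-7 in a 16-wide tile
    bit3 = (col >> 3) & 1          # basis 8: selects columns 8-11 in a 16-wide tile
    col &= ~0b1100                 # clear bits 2 and 3 ...
    return col | (bit2 << 3) | (bit3 << 2)   # ... and write them back swapped


# Columns 4-7 and 8-11 trade places inside each 16-column tile; columns 0-3 and
# 12-15 are unchanged.
assert [swap_columns(c) for c in range(16)] == [
    0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
]
```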

python/src/passes.cc

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ void init_triton_analysis(py::module &&m) {
 
 void init_triton_passes_common(py::module &&m) {
   using namespace mlir;
-  ADD_PASS_WRAPPER_0("add_strip_debug_info", createStripDebugInfoPass);
   ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass);
   ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass);
   ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass);

python/test/unit/conftest.py

Lines changed: 13 additions & 0 deletions
@@ -138,3 +138,16 @@ def fresh_knobs_except_libraries(monkeypatch):
         yield fresh_function()
     finally:
         reset_function()
+
+
+@pytest.fixture
+def with_allocator():
+    import triton
+    from triton.runtime._allocation import NullAllocator
+    from triton._internal_testing import default_alloc_fn
+
+    triton.set_allocator(default_alloc_fn)
+    try:
+        yield
+    finally:
+        triton.set_allocator(NullAllocator())
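
The new `with_allocator` fixture installs `default_alloc_fn` as Triton's global allocator for the duration of a test and restores the `NullAllocator` afterwards, whatever the test outcome. A minimal sketch of how a test would opt in; the test itself is hypothetical:

```python
def test_kernel_that_needs_scratch_memory(with_allocator):
    # While this test runs, triton.set_allocator(default_alloc_fn) is active,
    # so kernels that allocate device scratch memory can be launched here.
    ...
```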
