Commit 268b414

Merge commit 'c24aa15e30ebacd16799dadf1fe86d954ea5db97'
2 parents bdec54e + c24aa15

File tree

6 files changed, +112 -69 lines changed

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 63 additions & 57 deletions
@@ -228,33 +228,11 @@ sharedToLinearLayoutAMDRotating(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape);
 }
 
-} // namespace
-
-LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
-                                       NVMMASharedEncodingAttr shared,
+// Returns the layout of a single core matrix which tiles the nvmma layout
+LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
                                        bool disableSwizzle) {
-  MLIRContext *ctx = shared.getContext();
-  int rank = shape.size();
-  auto shapePerCTA = getShapePerCTA(shared, shape);
-  if (rank == 1) {
-    // TODO: Not sure if this is correct.
-    return combineCtaCgaWithShape(
-        LinearLayout::identity1D(shapePerCTA[0], S("offset"), S("dim0")),
-        shared.getCTALayout(), shape);
-  }
-  // Construct bases for a the layout's 2-dimensional tile.
-  assert(rank >= 2);
-  int batchDims = rank - 2;
+  auto *ctx = shared.getContext();
 
-  // Collapse all the outer dim into one. We will then create a layout for this
-  // shape and reshape it to the original shape.
-  std::array<int64_t, 2> collapsedShapePerCTA = {shapePerCTA[batchDims],
-                                                 shapePerCTA[batchDims + 1]};
-  for (int i = 0; i < batchDims; i++)
-    collapsedShapePerCTA[0] *= shapePerCTA[i];
-  if (shared.getTransposed()) {
-    std::swap(collapsedShapePerCTA[0], collapsedShapePerCTA[1]);
-  }
   int elemBitWidth = shared.getElementBitWidth();
   int tileWidthBytes = shared.getSwizzlingByteWidth();
   int vec = 128 / elemBitWidth;
@@ -273,25 +251,9 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
 
   int tileRows = 8;
   int tileCols = 8 * tileWidthBytes / elemBitWidth;
-  bool isFp4Padded = false;
-  if (auto sharedMMALayout =
-          dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(shared)) {
-    if (sharedMMALayout.getFp4Padded()) {
-      isFp4Padded = true;
-    }
-  }
+  bool isFp4Padded = shared.getFp4Padded();
   int packingFactor = isFp4Padded ? 2 : 1;
 
-  if (collapsedShapePerCTA[1] * packingFactor < tileCols ||
-      collapsedShapePerCTA[0] < tileRows) {
-    llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA to "
-                    "be at least ["
-                 << tileRows << ", " << tileCols << "], collapsedShapePerCTA: ["
-                 << collapsedShapePerCTA[0] << ", " << collapsedShapePerCTA[1]
-                 << "]\n";
-    llvm::report_fatal_error("Illegal shared layout");
-  }
-
   std::vector<std::vector<int>> bases2D;
   for (int col = 1; col < tileCols; col *= 2) {
     if (isFp4Padded) {
@@ -309,30 +271,75 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
   for (int row = 1; row < tileRows; row *= 2) {
     if (disableSwizzle) {
       bases2D.push_back({row, 0});
-      continue;
-    }
-    if (isFp4Padded) {
+    } else if (isFp4Padded) {
       int colPadded = vec * ((row / perPhase) % maxPhase);
       int colPacked = colPadded / 16 * 8 + colPadded % 8;
       bases2D.push_back({row, colPacked});
     } else {
       bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)});
     }
   }
+  auto outDimNames = standardOutDimNames(ctx, 2);
+  auto kRow = outDimNames[1];
+  auto kCol = outDimNames[0];
+  LinearLayout tileLayout =
+      LinearLayout({{S("offset"), bases2D}}, {kRow, kCol});
+  return tileLayout;
+}
+
+} // namespace
+
+LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
+                                       NVMMASharedEncodingAttr shared,
+                                       bool disableSwizzle) {
+  MLIRContext *ctx = shared.getContext();
+  int rank = shape.size();
+  auto shapePerCTA = getShapePerCTA(shared, shape);
+  if (rank == 1) {
+    // TODO: Not sure if this is correct.
+    return combineCtaCgaWithShape(
+        LinearLayout::identity1D(shapePerCTA[0], S("offset"), S("dim0")),
+        shared.getCTALayout(), shape);
+  }
+  // Construct bases for a the layout's 2-dimensional tile.
+  assert(rank >= 2);
+  int batchDims = rank - 2;
 
-  // Then distribute the remaining rows.
-  for (int row = tileRows; row < collapsedShapePerCTA[0]; row *= 2) {
-    bases2D.push_back({row, 0});
+  // Collapse all the outer dim into one. We will then create a layout for this
+  // shape and reshape it to the original shape.
+  std::array<int64_t, 2> collapsedShapePerCTA{shapePerCTA[batchDims],
+                                              shapePerCTA[batchDims + 1]};
+  for (int i = 0; i < batchDims; i++)
+    collapsedShapePerCTA[0] *= shapePerCTA[i];
+  if (shared.getTransposed()) {
+    std::swap(collapsedShapePerCTA[0], collapsedShapePerCTA[1]);
   }
 
+  auto tileLayout = getCoreMatrixLinearLayout(shared, disableSwizzle);
   auto outDimNames = standardOutDimNames(ctx, 2);
-  std::reverse(outDimNames.begin(), outDimNames.end());
-  LinearLayout tileLayout = LinearLayout({{S("offset"), bases2D}}, outDimNames);
-  // Expand the layout to convert the whole shape per CTA.
-  llvm::SmallDenseMap<StringAttr, int64_t> namedShape;
-  namedShape[outDimNames[0]] = collapsedShapePerCTA[0];
-  namedShape[outDimNames[1]] = collapsedShapePerCTA[1];
-  tileLayout = ensureLayoutNotSmallerThan(tileLayout, namedShape);
+  auto kRow = outDimNames[1];
+  auto kCol = outDimNames[0];
+  auto tileRows = tileLayout.getOutDimSize(kRow);
+  auto tileCols = tileLayout.getOutDimSize(kCol);
+
+  int packingFactor = shared.getFp4Padded() ? 2 : 1;
+  if (collapsedShapePerCTA[1] * packingFactor < tileCols ||
+      collapsedShapePerCTA[0] < tileRows) {
+    llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA to "
+                    "be at least ["
+                 << tileRows << ", " << (tileCols / packingFactor)
+                 << "], collapsedShapePerCTA: [" << collapsedShapePerCTA[0]
+                 << ", " << collapsedShapePerCTA[1] << "]\n";
+    llvm::report_fatal_error("Illegal shared layout");
+  }
+
+  // Distribute the remaining rows and cols.
+  auto kOffset = S("offset");
+  auto layout = tileLayout;
+  layout *= LinearLayout::identity1D(collapsedShapePerCTA[0] / tileRows,
+                                     kOffset, kRow);
+  layout *= LinearLayout::identity1D(collapsedShapePerCTA[1] / tileCols,
+                                     kOffset, kCol);
 
   // Reshape the layout to the N-D pre-transposed shape per CTA.
   SmallVector<int64_t> maybeTransposedShapePerCTA = shapePerCTA;
@@ -344,8 +351,7 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
                maybeTransposedShapePerCTA.begin() + 1,
                maybeTransposedShapePerCTA.end());
   }
-  auto reshapedLayout =
-      reshapeLayout(ctx, tileLayout, maybeTransposedShapePerCTA);
+  auto reshapedLayout = reshapeLayout(ctx, layout, maybeTransposedShapePerCTA);
 
   if (shared.getTransposed()) {
     SmallVector<int> order = {rank - 1};
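
As a quick illustration of the core-matrix bases built above, the small Python sketch below (not part of the commit) evaluates the same column-offset formula, vec * ((row / perPhase) % maxPhase), for the power-of-two row bases. The perPhase and maxPhase values are assumptions for a 128-byte swizzle of 16-bit elements; they are not taken from this diff.

# Sketch only: mirrors bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)})
# from getCoreMatrixLinearLayout, under assumed swizzle parameters.
elem_bitwidth = 16
vec = 128 // elem_bitwidth        # elements per 128-bit vector -> 8
tile_rows = 8
per_phase, max_phase = 1, 8       # assumed values for a 128B swizzle

row = 1
while row < tile_rows:
    col = vec * ((row // per_phase) % max_phase)
    print(f"offset base for row {row}: column {col}")
    row *= 2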

python/test/unit/conftest.py

Lines changed: 37 additions & 6 deletions
@@ -3,6 +3,7 @@
 import pathlib
 import pytest
 import tempfile
+from typing import Optional, Set
 
 
 def pytest_configure(config):
@@ -76,20 +77,50 @@ def fresh_triton_cache():
         os.environ.pop("TRITON_CACHE_DIR", None)
 
 
-@pytest.fixture
-def fresh_knobs(request, monkeypatch):
+def _fresh_knobs_impl(monkeypatch, skipped_attr: Optional[Set[str]] = None):
     from triton import knobs
+
+    if skipped_attr is None:
+        skipped_attr = set()
+
     knobs_map = {
         name: knobset
         for name, knobset in knobs.__dict__.items()
-        if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs
+        if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
     }
-    try:
+
+    def fresh_function():
         for name, knobset in knobs_map.items():
             setattr(knobs, name, knobset.copy().reset())
             for knob in knobset.knob_descriptors.values():
                 monkeypatch.delenv(knob.key, raising=False)
-        yield knobs
-    finally:
+        return knobs
+
+    def reset_function():
         for name, knobset in knobs_map.items():
             setattr(knobs, name, knobset)
+
+    return fresh_function, reset_function
+
+
+@pytest.fixture
+def fresh_knobs(monkeypatch):
+    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+    try:
+        yield fresh_function()
+    finally:
+        reset_function()
+
+
+@pytest.fixture
+def fresh_knobs_except_libraries(monkeypatch):
+    """
+    A variant of `fresh_knobs` that keeps library path
+    information from the environment as these may be
+    needed to successfully compile kernels.
+    """
+    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch, skipped_attr={"build", "nvidia", "amd"})
+    try:
+        yield fresh_function()
+    finally:
+        reset_function()
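
For context, here is a minimal usage sketch (not part of this commit) of the new fresh_knobs_except_libraries fixture; the test name is hypothetical, and the knob it touches (compilation.listener) is the same one used by test_compilation_listener.py below.

# Hypothetical test: the fixture resets every knob set except `build`,
# `nvidia`, and `amd`, so library paths needed for compilation survive.
def test_with_clean_knobs(fresh_knobs_except_libraries):
    knobs = fresh_knobs_except_libraries
    # compilation knobs start from their defaults (listener assumed to be None)
    assert knobs.compilation.listener is None
    knobs.compilation.listener = lambda *args, **kwargs: None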

python/test/unit/runtime/test_compilation_listener.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@ def cumsum_kernel(ptr):
     tl.store(block, tl.cumsum(x, 0))
 
 
-def test_compile_stats(device: str, fresh_knobs: Any, fresh_triton_cache: str) -> None:
+def test_compile_stats(device: str, fresh_knobs_except_libraries: Any, fresh_triton_cache: str) -> None:
     captured: Union[tuple[Union[ASTSource, IRSource], dict[str, Any], CompileTimes, bool], None] = None
 
     def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, Any], times: CompileTimes,
@@ -26,7 +26,7 @@ def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, Any],
         assert captured is None
         captured = (src, metadata, times, cache_hit)
 
-    fresh_knobs.compilation.listener = compile_listener
+    fresh_knobs_except_libraries.compilation.listener = compile_listener
 
     x = torch.randn(4, device=device)
     cumsum_kernel[(1, )](x)

python/test/unit/test_knobs.py

Lines changed: 7 additions & 1 deletion
@@ -2,6 +2,7 @@
 import pytest
 import shutil
 import triton
+from triton._internal_testing import is_hip
 
 from pathlib import Path
 
@@ -136,6 +137,7 @@ def test_read_env(truthy, falsey, fresh_knobs, monkeypatch):
     assert fresh_knobs.cache.override_dir == "/tmp/triton_home/.triton/override"
 
     from triton.runtime.cache import FileCacheManager
+
     assert fresh_knobs.cache.manager_class == FileCacheManager
 
     assert fresh_knobs.build.backend_dirs == {"/tmp/cuda/crt", "/tmp/cuda/rt"}
@@ -216,8 +218,12 @@ class TestManagerClass(FileCacheManager):
     assert fresh_knobs.cache.manager_class == FileCacheManager
 
 
+@pytest.mark.skipif(
+    is_hip(),
+    reason="PTXAS is not installed on AMD",
+)
 def test_nvidia_tool(fresh_knobs, tmp_path, monkeypatch):
-    triton_root = Path(__file__).parent.parent.parent / "triton"
+    triton_root = Path(fresh_knobs.__file__).parent
     default_ptxas = triton_root / "backends/nvidia/bin/ptxas"
 
     assert default_ptxas.exists()

python/triton/knobs.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def __set_name__(self, objclass: Type[object], name: str) -> None:
 
     def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
         if obj is None:
-            raise AttributeError("Cannot access {type(self)} on non-instance")
+            raise AttributeError(f"Cannot access {type(self)} on non-instance")
 
         if self.name in obj.__dict__:
             return self.transform(obj.__dict__[self.name])
@@ -311,7 +311,7 @@ def copy(self: knobs_type) -> knobs_type:
         return res
 
     def reset(self: knobs_type) -> knobs_type:
-        for knob in self.knobs.keys():
+        for knob in self.knob_descriptors.keys():
            delattr(self, knob)
         return self
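
To show the behaviour the reset() fix enables, here is a rough sketch (not from the commit) of the snapshot/reset/restore pattern the conftest fixtures above rely on; my_listener is a hypothetical callback and the default values are assumptions.

from triton import knobs

def my_listener(*args, **kwargs):           # hypothetical compilation listener
    pass

saved = knobs.compilation                   # keep the live knob set
knobs.compilation = saved.copy().reset()    # fresh copy; reset() delattr's each
                                            # knob so env vars / defaults apply
try:
    knobs.compilation.listener = my_listener
finally:
    knobs.compilation = saved               # restore the original knob set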

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -277,8 +277,8 @@ def make_ttgir(mod, metadata, opt, capability):
     passes.common.add_cse(pm)
     passes.common.add_symbol_dce(pm)
     if capability // 10 >= 9:
-        nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
         nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
+        nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
     passes.common.add_canonicalizer(pm)
     pm.run(mod)
     metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
