
Commit fbe22ae

Merge commit 'ff5c1e77ce8064501d9f260f5a14de195d74425f'

2 parents: 6f41a1d + ff5c1e7

43 files changed: +999 additions, -989 deletions


Makefile

Lines changed: 2 additions & 1 deletion

@@ -73,8 +73,9 @@ test-interpret: all
 
 .PHONY: test-proton
 test-proton: all
-	$(PYTEST) --tb=short -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py
+	$(PYTEST) --tb=short -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py -k "not test_overhead"
 	$(PYTEST) --tb=short -s third_party/proton/test/test_override.py
+	$(PYTEST) --tb=short -s third_party/proton/test/test_instrumentation.py::test_overhead
 
 .PHONY: test-python
 test-python: test-unit test-regression test-interpret test-proton
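
Note: the change deselects test_overhead from the parallel run and re-runs it serially afterwards, presumably because an overhead measurement is timing-sensitive under -n 8. A sketch of the same two-phase invocation via pytest's Python API (paths as in the Makefile; the parallel phase requires pytest-xdist):

import pytest

# Phase 1: parallel run, with the overhead test deselected via -k.
pytest.main([
    "--tb=short", "-s", "-n", "8",
    "third_party/proton/test",
    "--ignore=third_party/proton/test/test_override.py",
    "-k", "not test_overhead",
])

# Phase 2: run the deselected test alone, so concurrent workers
# cannot perturb its measurement.
pytest.main([
    "--tb=short", "-s",
    "third_party/proton/test/test_instrumentation.py::test_overhead",
])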

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 24 additions & 5 deletions

@@ -2806,15 +2806,34 @@ struct TritonGPUInferLayoutInterface
         mlir::dyn_cast<triton::gpu::DotOperandEncodingAttr>(operandEncodingB);
     if (!aEncoding && !bEncoding)
       return mlir::success();
-    auto mmaAEncoding =
-        mlir::dyn_cast_or_null<NvidiaMmaEncodingAttr>(aEncoding.getParent());
-    if (mmaAEncoding && mmaAEncoding.isHopper())
-      return success();
-    // Verify that the encodings are valid.
     if (!aEncoding || !bEncoding)
       return op->emitError("mismatching encoding between A and B operands");
+    // Verify that the encodings are valid.
     if (aEncoding.getKWidth() != bEncoding.getKWidth())
       return op->emitError("mismatching kWidth between A and B operands");
+
+    // Check if we have already selected an MMA version for Nvidia. If so,
+    // validate that the encodings are correct and compatible.
+    auto mmaAEncoding =
+        dyn_cast_or_null<NvidiaMmaEncodingAttr>(aEncoding.getParent());
+    auto mmaBEncoding =
+        dyn_cast_or_null<NvidiaMmaEncodingAttr>(bEncoding.getParent());
+    auto dotOp = cast<DotOp>(op);
+    auto resEnc = dotOp.getResult().getType().getEncoding();
+    auto mmaResEncoding = dyn_cast<NvidiaMmaEncodingAttr>(resEnc);
+    if (mmaAEncoding || mmaBEncoding || mmaResEncoding) {
+      // Check that they are all set and have the same version.
+      if (!mmaAEncoding || !mmaBEncoding || !mmaResEncoding)
+        return op->emitError("mismatching MMA encoding");
+      if (mmaAEncoding.getVersionMajor() != mmaBEncoding.getVersionMajor() ||
+          mmaAEncoding.getVersionMajor() != mmaResEncoding.getVersionMajor()) {
+        return op->emitError("mismatched MMA version.");
+      }
+      // Verify that the operands are supported on the selected MMA version.
+      if (!supportMMA(dotOp, mmaResEncoding.getVersionMajor()))
+        return op->emitError("unsupported MMA version");
+    }
     return success();
   }
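
Note: the rule the new verifier enforces is all-or-none with a single agreed version across A, B, and the result. A Python restatement for readability (a sketch; attribute names are simplified, the C++ above is authoritative):

def verify_mma_encodings(mma_a, mma_b, mma_res):
    # Each argument is the NvidiaMmaEncodingAttr attached to A, B, or the
    # result, or None if that value carries no MMA encoding yet.
    if not (mma_a or mma_b or mma_res):
        return  # No MMA version selected; nothing to validate.
    if not (mma_a and mma_b and mma_res):
        raise ValueError("mismatching MMA encoding")  # all-or-none
    if not (mma_a.version_major == mma_b.version_major == mma_res.version_major):
        raise ValueError("mismatched MMA version")
    # The C++ additionally calls supportMMA(dotOp, version) to confirm the
    # operand types are legal on the selected MMA version.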

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 4 additions & 95 deletions

@@ -470,93 +470,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(tileLayout, getCTALayout(), shape);
 }
 
-std::optional<LinearLayout>
-chooseLLDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
-                       int32_t elemBitWidth, unsigned instBitWidth,
-                       unsigned numLanesInShuffleGroup) {
-  using BaseTy = std::vector<std::vector<int32_t>>;
-  // This function will derive the layout for the ds_read_tr instruction
-  // based on the input layout (LL/DotLayout/...)
-  // The ds_read_tr instruction works on instBitWidth per lane and in groups of
-  // numLanesInShuffleGroup lanes.
-
-  // In this example we look at ds_read_b64_tr (instBitWidth = 64) and
-  // numLanesInShuffleGroup = 16 with 64 lanes per warp. Using M-continuous
-  // 16-bit input tensor A as an example. Each lane will load 4 consecutive
-  // elements (64-bit in total) along M. There are 4 consecutive lanes in total
-  // along M. Then the loaded elements are exchanged within the MxK=16x4 "base
-  // unit".
-  //        K0  K1  K2  K3
-  //       +---+---+---+---+
-  //  M0   |   |   |   |   |   M0,  K[0-3]: T0
-  //  M1   | T | T | T | T |   M1,  K[0-3]: T1
-  //  M2   | 0 | 4 | 8 |12 |   M2,  K[0-3]: T2
-  //  M3   |   |   |   |   |   M3,  K[0-3]: T3
-  //       +---+---+---+---+
-  //  M4   |   |   |   |   |   M4,  K[0-3]: T4
-  //  M5   | T | T | T | T |   M5,  K[0-3]: T5
-  //  M6   | 1 | 5 | 9 |13 |   M6,  K[0-3]: T6
-  //  M7   |   |   |   |   |   M7,  K[0-3]: T7
-  //       +---+---+---+---+        ==>
-  //  M8   |   |   |   |   |   M8,  K[0-3]: T8
-  //  M9   | T | T | T | T |   M9,  K[0-3]: T9
-  //  M10  | 2 | 6 |10 |14 |   M10, K[0-3]: T10
-  //  M11  |   |   |   |   |   M11, K[0-3]: T11
-  //       +---+---+---+---+
-  //  M12  |   |   |   |   |   M12, K[0-3]: T12
-  //  M13  | T | T | T | T |   M13, K[0-3]: T13
-  //  M14  | 3 | 7 |11 |15 |   M14, K[0-3]: T14
-  //  M15  |   |   |   |   |   M15, K[0-3]: T15
-  //       +---+---+---+---+
-
-  // Given the layout represented by `enc` and shape, we can derive the layout
-  // that ds_read_b64_tr need to have in order to perform a vectorized load of
-  // the elements. This can be done by rearranging the inner 4x16 element base
-  // unit in the LL by rearranging the first numReg register bases and the
-  // first numLane lane bases.
-  auto rotatePrefixes = [](BaseTy &regBase, std::size_t numReg,
-                           BaseTy &laneBase, std::size_t numLane) {
-    // Concatenate prefixes of the two vectors. Lane first and then regs.
-    // C D E F | A B
-    // Then copy over numReg to the regBase and numLane to laneBase
-    // C D | E F A B
-    BaseTy baseUnit(laneBase.begin(), laneBase.begin() + numLane);
-    llvm::append_range(
-        baseUnit, llvm::make_range(regBase.begin(), regBase.begin() + numReg));
-
-    std::copy(baseUnit.begin(), baseUnit.begin() + numReg, regBase.begin());
-    std::copy(baseUnit.begin() + numReg, baseUnit.end(), laneBase.begin());
-  };
-
-  auto ctx = enc.getContext();
-  assert(elemBitWidth == 8 || elemBitWidth == 16);
-  // Get how many reg bases and tile bases the ds_read_tr tile spans
-  unsigned numRegBases = llvm::Log2_32(instBitWidth / elemBitWidth);
-  unsigned numLaneBases = llvm::Log2_32(numLanesInShuffleGroup);
-
-  auto ldsTransLayout = triton::gpu::toLinearLayout(shape, enc);
-  auto bases = ldsTransLayout.getBases();
-  auto kRegister = S("register");
-  auto kLane = S("lane");
-
-  // Make sure that we have enough register bases to rotate, otherwise we
-  // can't return a valid ds_read_tr layout
-  if (ldsTransLayout.getInDimSizeLog2(kRegister) < numRegBases) {
-    return std::nullopt;
-  }
-  // We should always have enough lanes
-  assert(ldsTransLayout.getInDimSizeLog2(kLane) >= numLaneBases);
-  rotatePrefixes(bases[kRegister], numRegBases, bases[kLane], numLaneBases);
-  // Scale types double the elements for a total of 16 vgpr (still only 16
-  // elements contiguous). Need to adjust the lane basis to reflect that
-  if (elemBitWidth == 8 && numLanesInShuffleGroup == 8) {
-    assert(ldsTransLayout.getInDimSizeLog2(kLane) >= (numLaneBases + 1));
-    std::swap(bases[kLane][numLaneBases - 1], bases[kLane][numLaneBases]);
-  }
-
-  return LinearLayout(bases, ldsTransLayout.getOutDims(), false);
-}
-
 std::optional<LinearLayout>
 chooseDotDsReadTrLayout(DotOperandEncodingAttr dotMfmaLayout,
                         ArrayRef<int64_t> shape, int32_t elemBitWidth,

@@ -1461,14 +1374,10 @@ std::optional<LinearLayout>
 chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
                      int32_t elemBitWidth, unsigned instBitWidth,
                      unsigned numLanesInShuffleGroup) {
-  if (elemBitWidth == 4) {
-    auto dot = cast<DotOperandEncodingAttr>(enc);
-    return chooseDotDsReadTrLayout(dot, shape, elemBitWidth, instBitWidth,
-                                   numLanesInShuffleGroup);
-  } else {
-    return chooseLLDsReadTrLayout(enc, shape, elemBitWidth, instBitWidth,
-                                  numLanesInShuffleGroup);
-  }
+  assert(elemBitWidth == 4);
+  auto dot = cast<DotOperandEncodingAttr>(enc);
+  return chooseDotDsReadTrLayout(dot, shape, elemBitWidth, instBitWidth,
+                                 numLanesInShuffleGroup);
 }
 
 LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
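
Note: the removed rotatePrefixes helper is the core of the deleted path. A standalone Python sketch of what it did, using plain lists for the base vectors (illustrative, not live code):

def rotate_prefixes(reg_base, num_reg, lane_base, num_lane):
    # Concatenate the prefixes, lanes first, then registers:
    #   regs=[A, B], lanes=[C, D, E, F]  ->  base_unit=[C, D, E, F, A, B]
    base_unit = lane_base[:num_lane] + reg_base[:num_reg]
    # Redistribute: first num_reg entries back to regs, the rest to lanes:
    #   regs=[C, D], lanes=[E, F, A, B]
    reg_base[:num_reg] = base_unit[:num_reg]
    lane_base[:num_lane] = base_unit[num_reg:]

regs = [["A"], ["B"]]
lanes = [["C"], ["D"], ["E"], ["F"]]
rotate_prefixes(regs, 2, lanes, 4)
assert regs == [["C"], ["D"]]
assert lanes == [["E"], ["F"], ["A"], ["B"]]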

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 6 additions & 6 deletions

@@ -25,6 +25,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"

@@ -91,13 +92,12 @@ LogicalResult WarpGroupDotOp::verify() {
   if (retShapePerCTA[1] % 8 != 0)
     return emitOpError("WGMMA result N dimension must be divisible by 8");
 
-  auto aElemTy = getA().getType().getElementType();
-  if (!(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
-        aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
-        aElemTy.isF32()))
-    return emitOpError("WGMMA result element type must be F16, BF16, F32, "
-                       "F8E5M2, F8E4M3FN, or integer type");
+  // Verify MMA version is supported for operands.
+  int mmaVersion = nvmmaEnc.getVersionMajor();
+  if (!supportMMA(getA(), mmaVersion) || !supportMMA(getB(), mmaVersion))
+    return emitOpError("unsupported MMA version for the given operands");
 
+  auto aElemTy = getA().getType().getElementType();
   if (getMaxNumImpreciseAcc() < 32 &&
       (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
       resTy.getElementType().isF32()) {

python/src/gluon_ir.cc

Lines changed: 17 additions & 0 deletions

@@ -105,6 +105,8 @@ struct GluonLayouts {
   py::handle DistributedLinearLayout;
   py::handle DotOperandLayout;
   py::handle NVMMADistributedLayout;
+  py::handle TensorMemoryScalesLayout;
+  py::handle TensorMemoryLayout;
   py::handle NVMMASharedLayout;
   py::handle SwizzledSharedLayout;
   py::handle SharedLinearLayout;

@@ -120,6 +122,8 @@ struct GluonLayouts {
       py::module::import("triton.experimental.gluon.language.amd._layouts");
   auto intelLayouts =
       py::module::import("triton.experimental.gluon.language.intel._layouts");
+  auto blackwellLayouts = py::module::import(
+      "triton.experimental.gluon.language.nvidia.blackwell");
   AutoLayout = py::object(layouts.attr("AutoLayout")).release();
   BlockedLayout = py::object(layouts.attr("BlockedLayout")).release();
   SliceLayout = py::object(layouts.attr("SliceLayout")).release();

@@ -128,6 +132,10 @@ struct GluonLayouts {
   DotOperandLayout = py::object(layouts.attr("DotOperandLayout")).release();
   NVMMADistributedLayout =
       py::object(layouts.attr("NVMMADistributedLayout")).release();
+  TensorMemoryScalesLayout =
+      py::object(blackwellLayouts.attr("TensorMemoryScalesLayout")).release();
+  TensorMemoryLayout =
+      py::object(blackwellLayouts.attr("TensorMemoryLayout")).release();
   NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
   SwizzledSharedLayout =
       py::object(layouts.attr("SwizzledSharedLayout")).release();

@@ -268,6 +276,15 @@ py::object layoutToGluon(Attribute layout) {
         intelDpas.getExecutionSize(), intelDpas.getOpsPerChannel(),
         toStdVector(intelDpas.getWarpsPerCTA()),
         toStdVector(intelDpas.getRepCluster()), intelDpas.getThreadsPerWarp());
+  } else if (auto tmemScales =
+                 dyn_cast<ttng::TensorMemoryScalesEncodingAttr>(layout)) {
+    return layouts.TensorMemoryScalesLayout(std::vector<unsigned>{
+        tmemScales.getCTASplitM(), tmemScales.getCTASplitN()});
+  } else if (auto tmem = dyn_cast<ttng::TensorMemoryEncodingAttr>(layout)) {
+    return layouts.TensorMemoryLayout(
+        std::vector<unsigned>{tmem.getBlockM(), tmem.getBlockN()},
+        tmem.getColStride(),
+        std::vector<unsigned>{tmem.getCTASplitM(), tmem.getCTASplitN()});
   }
 
   throw py::value_error("Unhandled encoding encountered");
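
Note: with these bindings, layoutToGluon can map tensor-memory encodings back to the Python layout classes. A hypothetical construction showing the positional argument order implied by the C++ calls above (block, col_stride, cta_split_num); the concrete values and the use of tuples are assumptions:

from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    TensorMemoryScalesLayout,
)

# Argument order mirrors the C++ calls: (block, col_stride, cta_split_num).
tmem = TensorMemoryLayout((128, 128), 1, (1, 1))
# TensorMemoryScalesLayout takes only cta_split_num.
scales = TensorMemoryScalesLayout((1, 1))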

python/test/unit/tools/test_triton_to_gluon.py

Lines changed: 28 additions & 9 deletions

@@ -13,7 +13,7 @@
 
 
 def convert_kernel(kernel, kernel_name, tmp_path):
-    converted = convert_triton_to_gluon(kernel)
+    converted = convert_triton_to_gluon([kernel])
 
     # Write converted kernel to a file so @gluon.jit can retrieve source
     mod_path = tmp_path / "converted_kernel.py"

@@ -52,7 +52,7 @@ def test_simple_kernel(tmp_path):
     ref = torch.empty_like(x)
     add_kernel[grid](x, y, ref, n, BLOCK)
 
-    torch.testing.assert_close(out, ref)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
 
 
 @triton.jit

@@ -85,7 +85,7 @@ def test_triton_to_gluon_dot_minimal(tmp_path):
 
     ref = torch.empty_like(c)
     matmul_tile_kernel[grid](a, b, ref, M, N, K, num_warps=8)
-    torch.testing.assert_close(c, ref)
+    torch.testing.assert_close(c, ref, atol=0, rtol=0)
 
 
 @triton.jit

@@ -153,7 +153,7 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
     ref = torch.empty_like(output)
     matmul_kernel[grid](a, b, ref, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0),
                         output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K)
-    torch.testing.assert_close(output, ref)
+    torch.testing.assert_close(output, ref, atol=0, rtol=0)
 
 
 @triton.jit

@@ -177,7 +177,7 @@ def test_triton_to_gluon_descriptor_roundtrip(tmp_path):
     y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16)
     desc_ref = TensorDescriptor(y_ref, y_ref.shape, y_ref.stride(), block_shape)
     descriptor_store_kernel[grid](desc_ref, M, N, 1.0)
-    torch.testing.assert_close(y, y_ref)
+    torch.testing.assert_close(y, y_ref, atol=0, rtol=0)
 
 
 @triton.jit

@@ -204,7 +204,7 @@ def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path):
     y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16)
     desc_ref = TensorDescriptor(y_ref, y_ref.shape, y_ref.stride(), block_shape)
     descriptor_copy_kernel[grid](in_desc, desc_ref, M, N)
-    torch.testing.assert_close(y, y_ref)
+    torch.testing.assert_close(y, y_ref, atol=0, rtol=0)
 
 
 @triton.jit

@@ -232,7 +232,7 @@ def test_triton_reshape_trans(tmp_path):
     kernel[grid](x, y, out, n, BLOCK)
     ref = torch.empty_like(x)
     reshape_trans_kernel[grid](x, y, ref, n, BLOCK)
-    torch.testing.assert_close(out, ref)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
 
 
 BLOCK_SPLIT = tl.constexpr(256)

@@ -262,7 +262,7 @@ def test_split(tmp_path):
     kernel[grid](x, out)
     ref = torch.empty_like(x[:n])
     split_kernel[grid](x, ref)
-    torch.testing.assert_close(out, ref)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
 
 
 @triton.jit

@@ -281,4 +281,23 @@ def test_reduce_to_scalar(tmp_path):
     kernel[grid](out)
     ref = torch.empty_like(out)
     reduce_to_scalar_kernel[grid](ref)
-    torch.testing.assert_close(out, ref)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
+
+
+@triton.jit
+def num_threads_kernel(out_ptr):
+    num_threads: tl.constexpr = tl.extra.cuda.num_threads()
+    offs = tl.arange(0, num_threads)
+    tl.store(out_ptr + offs, 1)
+
+
+@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell")
+def test_num_threads(tmp_path):
+    kernel = convert_kernel(num_threads_kernel, "num_threads_kernel", tmp_path)
+
+    num_threads = 256
+    out = torch.empty(num_threads, dtype=torch.int32, device="cuda")
+    kernel[(1, )](out, num_warps=num_threads // 32)
+    ref = torch.empty_like(out)
+    num_threads_kernel[(1, )](ref, num_warps=num_threads // 32)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 3 additions & 0 deletions

@@ -9,6 +9,9 @@ def _realize_cta_layout(layout, rank):
     ctas_per_cga = layout.ctas_per_cga or [1] * rank
     cta_split_num = layout.cta_split_num or [1] * rank
     cta_order = layout.cta_order or list(reversed(range(rank)))
+    # Canonicalize CTA order to [n,n-1,...,0] if CTAsPerCGA is [1...1]. This matches logic in C++.
+    if all(num_cta == 1 for num_cta in ctas_per_cga):
+        cta_order = list(range(rank - 1, -1, -1))
     object.__setattr__(layout, "ctas_per_cga", ctas_per_cga)
     object.__setattr__(layout, "cta_split_num", cta_split_num)
     object.__setattr__(layout, "cta_order", cta_order)

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 0 deletions

@@ -416,6 +416,11 @@ def _check_same_layout(xs):
         _check(all(l == l0 for l in layouts[1:]),
                lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
 
+    def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction):
+        if ptr.type.is_block() and not val.type.is_block():
+            val = self.splat(val, ptr.type.get_block_shapes(), ptr.type.layout)
+        return super()._store_legacy(ptr, val, mask, boundary_check, cache, eviction)
+
     def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
                          reverse: bool) -> Tuple[TensorTy, ...]:
         shape = inputs[0].type.shape
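
Note: the override makes a scalar stored through a block of pointers splat implicitly to the block's shape and layout. A sketch of the user-visible effect; the kernel, module aliases, and layout argument are illustrative assumptions, not code from this commit:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

@gluon.jit
def fill_kernel(out_ptr, BLOCK: ttgl.constexpr, layout: ttgl.constexpr):
    offs = ttgl.arange(0, BLOCK, layout=layout)
    # The scalar 1.0 is splatted to the pointers' block shape and layout
    # by the _store_legacy override before the actual store.
    ttgl.store(out_ptr + offs, 1.0)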

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -68,6 +68,9 @@ def mangle(self) -> str:
         cta_split_str = (f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else "")
         return f"TL{block_str}{stride_str}{cta_split_str}TL"
 
+    def __hash__(self):
+        return hash((self.block, self.col_stride, self.cta_split_num))
+
 
 @dataclass(frozen=True, eq=True)
 class TensorMemoryScalesLayout:

@@ -91,6 +94,9 @@ def mangle(self) -> str:
         cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
         return f"TLS{cta_split_str}TLS"
 
+    def __hash__(self):
+        return hash(self.cta_split_num)
+
 
 @constexpr_function
 def get_tmem_reg_layout(
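
Note: the explicit __hash__ methods keep these layout classes usable as dict keys and set members alongside their eq=True equality. A small hypothetical usage (constructor arguments follow the same assumption as the gluon_ir.cc note above):

from triton.experimental.gluon.language.nvidia.blackwell import TensorMemoryLayout

a = TensorMemoryLayout((128, 64), 1, (1, 1))
b = TensorMemoryLayout((128, 64), 1, (1, 1))

# Equal layouts hash identically, so they collide as cache keys.
cache = {a: "compiled"}
assert b in cache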

python/triton/experimental/gluon/language/nvidia/hopper/tma.py

Lines changed: 1 addition & 0 deletions

@@ -116,6 +116,7 @@ def make_tensor_descriptor(
     _semantic=None,
 ) -> tensor_descriptor:
     padding_option = _unwrap_if_constexpr(padding_option)
+    block_shape = _unwrap_if_constexpr(block_shape)
 
     ndim = len(shape)
     if not (1 <= ndim <= 5):
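
Note: the added unwrap matters when make_tensor_descriptor is reached from JIT-ed code, where block_shape can arrive wrapped in constexpr; unwrapping restores the plain list for the shape checks that follow. A minimal sketch of the helper's behavior (a local mirror of _unwrap_if_constexpr, assuming the constexpr wrapper from triton.language.core):

from triton.language.core import constexpr

def unwrap_if_constexpr(v):
    # Same idea as the helper used above: peel a constexpr wrapper if present.
    return v.value if isinstance(v, constexpr) else v

block_shape = constexpr([64, 64])
assert unwrap_if_constexpr(block_shape) == [64, 64]
assert unwrap_if_constexpr([64, 64]) == [64, 64]  # plain values pass through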
