Skip to content

Commit 2c10050

Browse files
Merge commit '9743ec0dca5bbd9dbce20adc3ee273af6b095f94'
2 parents 541ff05 + 9743ec0 commit 2c10050

File tree

26 files changed

+1013
-290
lines changed

26 files changed

+1013
-290
lines changed

include/triton/Analysis/Utility.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
231231
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
232232
RankedTensorType dstTy);
233233

234+
// Check if MFMA layout can be converted to the dot operand
235+
// layout using warp shuffle.
236+
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
237+
RankedTensorType dstTy);
238+
234239
// TODO: Move utility functions that belong to ConvertLayoutOp to class
235240
// ConvertLayoutOpHelper in the future
236241
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);

lib/Analysis/Utility.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/IR/Dialect.h"
1212
#include "mlir/IR/Matchers.h"
1313
#include "mlir/Support/LLVM.h"
14+
#include "triton/Conversion/MLIRTypes.h"
1415
#include "triton/Dialect/Triton/IR/Dialect.h"
1516
#include "triton/Dialect/Triton/IR/Utility.h"
1617
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -650,6 +651,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
650651
return ans;
651652
}
652653

654+
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
655+
RankedTensorType dstTy) {
656+
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
657+
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
658+
if (!mfmaLayout || !dotOperandLayout)
659+
return false;
660+
661+
// Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
662+
return dotOperandLayout.getParent() == mfmaLayout &&
663+
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
664+
dotOperandLayout.getKWidth() == 8 &&
665+
getContigPerThread(mfmaLayout)[1] == 4 &&
666+
((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
667+
(mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
668+
triton::type::isFloat8(srcTy.getElementType()) &&
669+
triton::type::isFloat8(dstTy.getElementType()) &&
670+
mfmaLayout.getWarpsPerCTA()[1] == 1;
671+
}
672+
653673
// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
654674
// under kBlock, kWarp or kLane (in that order). The idea here is that if we
655675
// have a transformation that's the identity on kBlock, we don't need to use
@@ -749,7 +769,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
749769
return !cvtReordersRegisters(srcTy, dstTy) &&
750770
!triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
751771
!isBlockedToDotShortcut(srcTy, dstTy) &&
752-
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
772+
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
773+
// to be removed when generalized warp shuffle conversions
774+
// are ready:
775+
!matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
753776
}
754777

755778
bool atomicNeedsSharedMemory(Value value) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
402402
return failure();
403403
}
404404

405+
// The following check can be removed when generalized warp shuffle
406+
// conversions are ready:
407+
if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
408+
return failure();
409+
}
410+
405411
assert(cvtNeedsSharedMemory(srcTy, dstTy));
406412

407413
SmallVector<Value> inVals =

python/src/ir.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,7 @@ void init_triton_ir(py::module &&m) {
605605
"Function argument index out of range");
606606
return self.getArgument(idx);
607607
})
608+
.def("get_num_args", &FuncOp::getNumArguments)
608609
.def(
609610
"add_entry_block",
610611
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },

python/test/unit/language/test_compile_errors.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def kernel():
1717
a += 1 # noqa
1818

1919
with pytest.raises(CompilationError) as e:
20-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
20+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
2121

2222
try:
2323
assert "is not defined" in str(e.value), "error should mention the undefined variable"
@@ -32,7 +32,7 @@ def kernel():
3232
0 + "a"
3333

3434
with pytest.raises(CompilationError) as e:
35-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
35+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
3636

3737
try:
3838
assert "at 2:4:" in str(e.value), "error should point to the 0"
@@ -47,7 +47,7 @@ def kernel():
4747
tl.static_assert(isinstance(0, tl.tensor))
4848

4949
with pytest.raises(CompilationError) as e:
50-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
50+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
5151

5252
try:
5353
assert isinstance(e.value, CompileTimeAssertionFailure)
@@ -66,7 +66,7 @@ def kernel():
6666
not (0, 0)
6767

6868
with pytest.raises(CompilationError) as e:
69-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
69+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
7070

7171
try:
7272
assert e.value.__cause__ is None
@@ -83,7 +83,7 @@ def kernel():
8383
1.0 << 1
8484

8585
with pytest.raises(CompilationError) as e:
86-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
86+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
8787

8888
try:
8989
assert "at 2:4:" in str(e.value), "error should point to the 1.0"
@@ -107,7 +107,7 @@ def kernel():
107107
nested_call()
108108

109109
with pytest.raises(CompilationError) as e:
110-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
110+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
111111

112112
try:
113113
inner = e.value.__cause__
@@ -130,7 +130,7 @@ def kernel():
130130
tl.expand_dims(None, -1)
131131

132132
with pytest.raises(CompilationError) as e:
133-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
133+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
134134

135135
try:
136136
inner = e.value.__cause__
@@ -157,7 +157,7 @@ def kernel():
157157
a = two_returns()
158158
a + tl.arange(0, 4) # only works if we took the first return
159159

160-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
160+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
161161

162162

163163
def test_not_const_annotate_no_err():
@@ -166,7 +166,7 @@ def test_not_const_annotate_no_err():
166166
def kernel(N: int = 1):
167167
pass
168168

169-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
169+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))
170170

171171

172172
@triton.jit
@@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
186186
a = returns_branched_on_constexpr(N)
187187
a + tl.arange(0, 4)
188188

189-
triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))
189+
triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0}))
190190

191191
@triton.jit
192192
def kernel2(N: tl.constexpr):
193193
a = returns_branched_on_constexpr(N)
194194
a + tl.arange(0, 8)
195195

196-
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))
196+
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1}))
197197

198198

199199
@triton.jit
@@ -211,7 +211,7 @@ def kernel(N: int):
211211
returns_branched_on_non_constexpr(N)
212212

213213
with pytest.raises(CompilationError) as e:
214-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
214+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))
215215

216216
try:
217217
assert "at 2:4:" in str(e.value), "error should point to the function call"
@@ -227,7 +227,7 @@ def kernel():
227227
tl.arange(2, 7)
228228

229229
with pytest.raises(CompilationError) as e:
230-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
230+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
231231
assert str(e.value.__cause__) == "arange's range must be a power of 2"
232232

233233

@@ -238,7 +238,7 @@ def kernel():
238238
tl.full((33, ), 0, dtype=tl.int64)
239239

240240
with pytest.raises(CompilationError) as e:
241-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
241+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
242242
assert str(e.value.__cause__) == "Shape element 0 must be a power of 2"
243243

244244

@@ -251,7 +251,7 @@ def kernel():
251251
a = CAPTURED # noqa
252252

253253
with pytest.raises(CompilationError) as e:
254-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
254+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
255255
assert "CAPTURED is not defined" in str(e.value)
256256

257257

@@ -265,7 +265,7 @@ def kernel():
265265
a = GLOBAL # noqa
266266

267267
with pytest.raises(CompilationError) as e:
268-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
268+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
269269
assert "global variable" in str(e.value)
270270

271271

@@ -279,7 +279,7 @@ def kernel():
279279
a = CONSTEXPR_ANNOTATED_GLOBAL # noqa
280280

281281
# No error.
282-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
282+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
283283

284284

285285
CONSTEXPR_GLOBAL = tl.constexpr(42)
@@ -292,7 +292,7 @@ def kernel():
292292
a = CONSTEXPR_GLOBAL # noqa
293293

294294
# No error.
295-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
295+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
296296

297297

298298
TYPE_ALIAS = tl.pointer_type(tl.int32)
@@ -305,7 +305,7 @@ def kernel():
305305
a = TYPE_ALIAS # noqa
306306

307307
# No error.
308-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
308+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
309309

310310

311311
def test_global_access_in_fn_default_arg():
@@ -315,7 +315,7 @@ def kernel(a=GLOBAL):
315315
pass
316316

317317
# No error.
318-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={}))
318+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={}))
319319

320320

321321
def test_defaults_assign_no_err():
@@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
324324
def kernel(a=1, B: tl.constexpr = ""):
325325
pass
326326

327-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))
327+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""}))
328328

329329

330330
def test_where_warning(fresh_triton_cache):
@@ -337,7 +337,7 @@ def kernel():
337337
tl.where(a, b, c)
338338

339339
with pytest.warns(UserWarning):
340-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
340+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
341341

342342

343343
@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
@@ -371,7 +371,8 @@ def dtype_kernel(dtype: tl.constexpr):
371371
ctx = pytest.raises(CompilationError, match="")
372372

373373
with ctx as e:
374-
triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype}))
374+
triton.compile(
375+
triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype}))
375376

376377
if dtype not in supported_dtypes:
377378
try:
@@ -390,7 +391,7 @@ def dot_kernel():
390391
tl.dot(a, b, max_num_imprecise_acc=128)
391392

392393
with pytest.raises(CompilationError) as e:
393-
triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={}))
394+
triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={}))
394395
try:
395396
assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)")
396397
except AssertionError as assertion_err:

python/test/unit/language/test_core.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4407,15 +4407,17 @@ def kernel(x):
44074407
def test_value_specialization(value: int, value_type: str, device) -> None:
44084408

44094409
def repr(specialization):
4410-
spec_type = specialization.signature["VALUE"]
4411-
return f"kernel_{spec_type}"
4410+
ty = specialization.signature["value1"]
4411+
cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
4412+
return f"kernel_{ty}_{cst}"
44124413

44134414
@triton.jit(repr=repr)
4414-
def kernel(VALUE, X):
4415+
def kernel(value1, is_one, X):
44154416
pass
44164417

44174418
x = torch.tensor([3.14159], device=device)
4418-
h = kernel[(1, )](value, x)
4419+
h = kernel[(1, )](value, 1, x)
4420+
assert "is_one" in h.name
44194421
assert value_type in h.name
44204422

44214423

@@ -6186,6 +6188,19 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r
61866188
torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))
61876189

61886190

6191+
def test_dtype(device):
6192+
6193+
@triton.jit
6194+
def kernel(X):
6195+
dtype_x: tl.constexpr = X.dtype.element_ty
6196+
tl.static_assert(dtype_x == tl.int32)
6197+
tl.static_assert(dtype_x == tl.constexpr(tl.int32))
6198+
tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32))
6199+
6200+
X = torch.zeros(1, dtype=torch.int32, device=device)
6201+
kernel[(1, )](X)
6202+
6203+
61896204
def test_side_effectful_scan(device):
61906205
if device != "cuda":
61916206
pytest.xfail()

python/test/unit/language/test_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def kernel():
2323
pass
2424

2525
try:
26-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
26+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
2727
except Exception as e:
2828
pytest.fail(f"triton compile failed with error: {e}")
2929

0 commit comments

Comments (0)