Commit b669c0d

Merge commit '4dfdc32ff291c2ffab8f20135e94dadd47a9a0cc'
2 parents: 081ae01 + 4dfdc32

File tree: 19 files changed (+390, −162 lines)

.github/workflows/integration-tests-amd.yml

Lines changed: 0 additions & 32 deletions
```diff
@@ -60,26 +60,6 @@ jobs:
             ~/.triton/nvidia
             ~/.triton/json
           key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-json-${{ steps.cache-key.outputs.json }}
-      - # Cache ~/.cache/ccache to speed up compilation.
-        #
-        # On branch `main` we always start from an empty cache, i.e. we skip the
-        # "restore" step. This is to prevent the caches from accumulating stale
-        # files over time.
-        name: Restore cache of ccache and Triton compilation artifacts
-        id: restore-build-cache
-        if: github.ref != 'refs/heads/main'
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            ~/.ccache
-          # Restore the most recent cache entry.
-          restore-keys: |
-            triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-
-            triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-
-          # We expect this cache key never to hit and for us to fall back
-          # unconditionally to the restore-key, so it doesn't actually matter
-          # what we put here (so long as it doesn't hit an existing key).
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
       - name: Inspect cache directories
         run: |
           mkdir -p ~/.triton
@@ -152,18 +132,6 @@ jobs:
 
           mkdir -p ~/.ccache
           du -h -d 1 ~/.ccache
-      - # If we're on branch `main`, save the ccache Triton compilation artifacts
-        # to the cache so they can be used by other (non-main) CI runs.
-        #
-        # (It wouldn't be a problem to save the cache on every run, because github
-        # evicts cache entries LRU, but maybe this saves a bit of time in CI.)
-        name: Save ccache and Triton compilation artifacts to cache
-        if: github.ref == 'refs/heads/main'
-        uses: actions/cache/save@v4
-        with:
-          path: |
-            ~/.ccache
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
       - name: Clean up caches
         # Always cleanup the worker, even if builds or tests failed
         if: always()
```
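The removed steps relied on the standard actions/cache key/restore-keys fallback: the per-run `key` ends in a datetime and is expected never to hit, so every non-main run restores the most recent ccache/artifact cache saved from `main`. Below is a minimal sketch of that selection logic, assuming GitHub's documented behaviour of preferring the newest cache whose key starts with a restore-key prefix; the `pick_cache` helper and the key strings are made-up examples, not part of actions/cache.

```python
def pick_cache(key, restore_keys, existing):
    """Pick a cache entry the way actions/cache/restore does (simplified).

    existing: dict mapping cache key -> creation timestamp.
    """
    if key in existing:                              # exact match first
        return key
    for prefix in restore_keys:                      # then prefixes, in order
        hits = [k for k in existing if k.startswith(prefix)]
        if hits:
            return max(hits, key=existing.get)       # newest matching entry wins
    return None

# Hypothetical keys: RUNNER_TYPE, llvm hash, and datetime values are invented.
caches = {
    "triton-artifacts-Linux-X64-a100-llvm-abc-20240101": 1,
    "triton-artifacts-Linux-X64-a100-llvm-abc-20240105": 5,
}
chosen = pick_cache(
    "triton-artifacts-Linux-X64-a100-llvm-abc-20240107",   # never hits exactly
    ["triton-artifacts-Linux-X64-a100-llvm-abc-",
     "triton-artifacts-Linux-X64-a100-"],
    caches)
assert chosen.endswith("20240105")
```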

.github/workflows/integration-tests-nvidia.yml

Lines changed: 0 additions & 32 deletions
```diff
@@ -57,26 +57,6 @@ jobs:
             ~/.triton/nvidia
             ~/.triton/json
           key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-json-${{ steps.cache-key.outputs.json }}
-      - # Cache ~/.cache/ccache to speed up compilation.
-        #
-        # On branch `main` we always start from an empty cache, i.e. we skip the
-        # "restore" step. This is to prevent the caches from accumulating stale
-        # files over time.
-        name: Restore cache of ccache and Triton compilation artifacts
-        id: restore-build-cache
-        if: github.ref != 'refs/heads/main'
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            ~/.ccache
-          # Restore the most recent cache entry.
-          restore-keys: |
-            triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-
-            triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-
-          # We expect this cache key never to hit and for us to fall back
-          # unconditionally to the restore-key, so it doesn't actually matter
-          # what we put here (so long as it doesn't hit an existing key).
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
       - name: Inspect cache directories
         run: |
           mkdir -p ~/.triton
@@ -130,15 +110,3 @@ jobs:
 
           mkdir -p ~/.ccache
           du -h -d 1 ~/.ccache
-      - # If we're on branch `main`, save the ccache Triton compilation artifacts
-        # to the cache so they can be used by other (non-main) CI runs.
-        #
-        # (It wouldn't be a problem to save the cache on every run, because github
-        # evicts cache entries LRU, but maybe this saves a bit of time in CI.)
-        name: Save ccache and Triton compilation artifacts to cache
-        if: github.ref == 'refs/heads/main'
-        uses: actions/cache/save@v4
-        with:
-          path: |
-            ~/.ccache
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
```

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 111 additions & 10 deletions
```diff
@@ -1515,32 +1515,133 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) {
     return {};
   auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());
 
-  // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
+  // We currently only support transposed [B]F16 MFMA32x32 and MFMA16x16 on
+  // CDNA4.
   bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
+  bool isMfma16 = mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16;
+
+  auto valShape = valType.getShape();
+  // For mfma16x16, to use the in-wavefront swap, the tiles involved must live
+  // in one wavefront when there are multiple tiles, which means
+  // warpsPerCTA = [numWarps, 1] and at least two tiles along the N dim. For
+  // now this is only possible for FA-like kernels, since during mfma
+  // generation the warpsPerCTA of the head dot in the chain is reshaped to
+  // [numWarps, 1].
+  // TODO: The transformation cannot be applied to GEMM-like kernels yet;
+  // support will be added later.
+  bool validForMfma16 = isMfma16 && valShape.back() >= 16 * 2 &&
+                        mfmaLayout.getWarpsPerCTA().back() == 1;
+
   Type elemType = valType.getElementType();
   if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
         mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
-        isMfma32))
+        (isMfma32 || validForMfma16)))
     return {};
 
-  auto valShape = valType.getShape();
   LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
   auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
   StringAttr dimM = mfmaOutDims[0];
   StringAttr dimN = mfmaOutDims[1];
-
   auto swapLL = LinearLayout::empty();
   // The rows are kept as is with an identity linear layout.
   swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
-  // In transposed mfma32 layout, each thread holds 4 consecutive values along N
-  // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11
-  // (owned by thread 0-31) every 16 columns to make each thread holds 8
-  // elements. This would mean exchange the 2nd and 3rd basis vector from an
-  // identity linear layout.
+  /*
+  clang-format off
+  In the transposed mfma32 layout, each thread holds 4 consecutive values along
+  the N dim. We want to exchange columns 4-7 (owned by threads 32-63, BLK0) and
+  columns 8-11 (owned by threads 0-31, BLK1) every 16 columns so that each
+  thread holds 8 elements. This means exchanging the 2nd and 3rd basis vectors
+  of an identity linear layout on tensor elements.
+
+  Correspondingly, for the transposed mfma16 layout, the output of a transposed
+  mfma16x16 is:
+
+                                   N/register
+  M/Lane        v0       v1       v2       v3       v4       v5       v6       v7
+              -------------------------------------------------------------------------
+  row0:  0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  row1: 16-31 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  row3: 48-63 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  which means columns v0 to v3 belong to one output of mfma16x16 and columns
+  v4 to v7 belong to another output of mfma16x16.
+
+  The following graph is the same as the one above, except the tile number is
+  replaced with coordinates in the tensor:
+                         N/register
+           -----------------------------------------------
+  M/lane   |(0, 0)   ... (0, 3)   | (0, 16)  ... (0, 19)  |
+           |....                  | sub-tensor-0          |
+           |(15, 0)  ... (15, 3)  | (15, 16) ... (15, 19) |
+           -----------------------------------------------
+           |(0, 4)   ... (0, 7)   | (0, 20)  ... (0, 23)  |
+           |sub-tensor-1          | ....                  |
+           |(15, 4)  ... (15, 7)  | (15, 20) ... (15, 23) |
+           -----------------------------------------------
+           |(0, 8)   ... (0, 11)  | (0, 24)  ... (0, 27)  |
+           |....                  | sub-tensor-2          |
+           |(15, 8)  ... (15, 11) | (15, 24) ... (15, 27) |
+           -----------------------------------------------
+           |(0, 12)  ... (0, 15)  | (0, 28)  ... (0, 31)  |
+           |sub-tensor-3          | ....                  |
+           |(15, 12) ... (15, 15) | (15, 28) ... (15, 31) |
+           -----------------------------------------------
+  The basis vectors for register and lane are:
+    Register = {{0, 1}, {0, 2}}
+    Lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}
+  With this layout, only 4xfp16 can be packed in the final global store.
+
+  To use 128-bit global stores, we need to pack 8 elements, which means the
+  layout should look like:
+                                   N/register
+  M/Lane        v0       v1       v2       v3       v4       v5       v6       v7
+              -------------------------------------------------------------------------
+  row0:  0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
+              -------------------------------------------------------------------------
+  row1: 16-31 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
+              -------------------------------------------------------------------------
+  row3: 48-63 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+
+  The following graph is the same as the one above, except the tile number is
+  replaced with coordinates in the tensor:
+                         N/register
+           -----------------------------------------------
+           |(0, 0)   ... (0, 3)   | (0, 4)   ... (0, 7)   |
+           |....                  | sub-tensor-1          |
+           |(15, 0)  ... (15, 3)  | (15, 4)  ... (15, 7)  |
+           -----------------------------------------------
+           |(0, 16)  ... (0, 19)  | (0, 20)  ... (0, 23)  |
+           |sub-tensor-0          | ....                  |
+           |(15, 16) ... (15, 19) | (15, 20) ... (15, 23) |
+           -----------------------------------------------
+           |(0, 8)   ... (0, 11)  | (0, 12)  ... (0, 15)  |
+           |....                  | sub-tensor-3          |
+           |(15, 8)  ... (15, 11) | (15, 12) ... (15, 15) |
+           -----------------------------------------------
+           |(0, 24)  ... (0, 27)  | (0, 28)  ... (0, 31)  |
+           |sub-tensor-2          | ....                  |
+           |(15, 24) ... (15, 27) | (15, 28) ... (15, 31) |
+           -----------------------------------------------
+  which means we need to exchange sub-tensor-0 with sub-tensor-1 and
+  sub-tensor-2 with sub-tensor-3. The basis vectors for register and lane are:
+    Register = {{0, 1}, {0, 2}, {0, 4}}
+    Lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}, {0, 8}}
+
+  The steps to get this layout: first, check that the last dim of warpsPerCTA
+  is 1, so v_permlane16 can be used. Then exchange the 2nd and 4th basis
+  vectors of an identity linear layout and compose it with the original
+  mfma16 LL.
+  clang-format on
+  */
+  auto destIdxInBases = isMfma32 ? 3 : 4;
   std::vector<std::vector<int32_t>> dimNBases(mfmaLL.getOutDimSizeLog2(dimN));
   std::generate(dimNBases.begin(), dimNBases.end(),
                 [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
-  std::swap(dimNBases[2], dimNBases[3]);
+  std::swap(dimNBases[2], dimNBases[destIdxInBases]);
   swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});
 
   return mfmaLL.compose(swapLL);
```
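The comment above boils down to a bit permutation of the column index: swapping two basis vectors of an identity linear layout swaps the corresponding bits of that index. The sketch below demonstrates the effect without Triton's `LinearLayout` class; `apply_ll` and `swapped_bases` are illustrative helpers, not Triton APIs.

```python
def apply_ll(bases, idx):
    """XOR together the basis vectors selected by the set bits of idx."""
    out = 0
    for bit, basis in enumerate(bases):
        if (idx >> bit) & 1:
            out ^= basis
    return out

def swapped_bases(n_cols, i, j):
    """Identity bases over a power-of-two column range, with bases i and j exchanged."""
    bases = [1 << b for b in range(n_cols.bit_length() - 1)]
    bases[i], bases[j] = bases[j], bases[i]
    return bases

# mfma32 case: swapping the 2nd and 3rd basis vectors exchanges columns 4-7
# with columns 8-11 inside every 16-column group.
m32 = [apply_ll(swapped_bases(32, 2, 3), c) for c in range(16)]
assert m32 == [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]

# mfma16 case (this commit, destIdxInBases == 4): swapping the 2nd and 4th
# basis vectors exchanges columns 4-7 with 16-19 and 12-15 with 24-27, i.e.
# sub-tensor-0 with sub-tensor-1 and sub-tensor-2 with sub-tensor-3 above.
m16 = [apply_ll(swapped_bases(32, 2, 4), c) for c in range(32)]
assert m16[4:8] == [16, 17, 18, 19] and m16[16:20] == [4, 5, 6, 7]
assert m16[12:16] == [24, 25, 26, 27] and m16[24:28] == [12, 13, 14, 15]
```

After composing this swap with the mfma16 linear layout, each lane owns 8 consecutive N elements, which is what allows the final store to be packed into a single 128-bit instruction.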

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 7 additions & 0 deletions
```diff
@@ -941,6 +941,13 @@ LogicalResult getConvertBackwardSlice(
       auto srcEncoding = inferSrcEncoding(definingOp, encoding);
       if (!srcEncoding)
         return failure();
+      // If the inferred layout matches the original one we don't need to keep
+      // propagating.
+      if (auto operandType =
+              dyn_cast<RankedTensorType>(operand.get().getType())) {
+        if (srcEncoding == operandType.getEncoding())
+          continue;
+      }
       enqueue(operand, srcEncoding);
     }
     continue;
```
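The new check short-circuits the backward walk once the inferred layout is already the one the operand carries. A simplified worklist model of that behaviour is sketched below; `defining_op`, `infer_src_encoding`, and `encoding_of` are illustrative stand-ins for the C++ helpers, not real Triton APIs.

```python
from collections import deque

def convert_backward_slice(root, root_encoding, defining_op,
                           infer_src_encoding, encoding_of):
    """Collect (value, encoding) pairs reachable by walking producers backwards."""
    queue = deque([(root, root_encoding)])
    slice_ = set()
    while queue:
        value, encoding = queue.popleft()
        if (value, encoding) in slice_:
            continue
        slice_.add((value, encoding))
        op = defining_op(value)
        if op is None:
            continue
        for operand in op.operands:
            src_encoding = infer_src_encoding(op, encoding)
            if src_encoding is None:
                raise RuntimeError("failed to infer a source encoding")
            # New in this commit: if the inferred layout matches the layout the
            # operand already has, nothing upstream needs to change, so stop
            # propagating along this edge instead of re-enqueueing it.
            if src_encoding == encoding_of(operand):
                continue
            queue.append((operand, src_encoding))
    return slice_
```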

python/test/unit/language/test_core.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -5989,14 +5989,9 @@ def kernel(Out):
     DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2),
     DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1),
     MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]),
-    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2),
     DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2),
     DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2),
-    DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2),
     DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8),
-    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8),
-    DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8),
-    DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8),
     SliceLayout(
         dim=1,
         parent=DotOperandLayout(parent=MmaLayout([3, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [16, 32, 16]),
```

python/test/unit/language/test_frontend.py

Lines changed: 61 additions & 0 deletions
```diff
@@ -179,6 +179,11 @@ def __init__(self, a, b):
     def create(a):
         return AggregateWithConstexpr(a, tl.constexpr(42))
 
+    @triton.jit
+    def modify(self, a):
+        self.a = a
+        return self
+
 
 @triton.jit
 def add_rhs_constexpr(agg):
@@ -196,3 +201,59 @@ def test_aggregate_with_constexpr():
     # CHECK: tt.func private @"test_frontend.add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
     # CHECK: %cst = arith.constant dense<42> : tensor<4xi32>
     # CHECK: arith.addi %arg0, %cst : tensor<4xi32>
+
+
+@tl.constexpr_function
+def constexpr_function(x):
+    return x + 1
+
+
+@filecheck_test
+@triton.jit
+def test_constexpr_function_from_jit():
+    # CHECK-LABEL: test_constexpr_function
+    x: tl.constexpr = constexpr_function(7)
+    # CHECK: make_range {end = 8 : i32, start = 0 : i32}
+    tl.arange(0, x)
+
+
+def test_constexpr_function_from_python():
+    assert constexpr_function(7) == 8
+
+
+@triton.jit
+def swap(pair):
+    return pair.second, pair.first
+
+
+@filecheck_test
+@triton.jit
+def test_assign_tuple_attrs():
+    # CHECK-LABEL: test_assign_tuple_attrs
+    p = Pair(tl.arange(0, 4), tl.arange(4, 8))
+    # CHECK: [[P:%.*]]:2 = tt.call @{{.*}}swap
+    p.first, p.second = swap(p)
+    # CHECK: call @{{.*}}anchor{{.*}}([[P]]#0)
+    # CHECK: call @{{.*}}anchor{{.*}}([[P]]#1)
+    anchor(p.first)
+    anchor(p.second)
+
+
+@filecheck_test
+@triton.jit
+def test_reassign_aggregate_with_constexpr():
+    # CHECK-LABEL: test_reassign_aggregate_with_constexpr
+    agg = AggregateWithConstexpr.create(tl.arange(0, 4))
+    var = 1
+    # CHECK: [[AGG:%.*]] = scf.if {{.*}} -> (tensor<4xi32>)
+    # CHECK: [[VALUE:%.*]] = tt.call {{.*}}modify
+    # CHECK: yield [[VALUE]]
+    # CHECK: else
+    # CHECK: [[VALUE:%.*]] = tt.call {{.*}}modify
+    # CHECK: yield [[VALUE]]
+    if var == 0:
+        agg = agg.modify(tl.arange(4, 8))
+    else:
+        agg = agg.modify(tl.arange(8, 12))
+    # CHECK: call @{{.*}}anchor{{.*}}([[AGG]])
+    anchor(agg)
```

python/triton/compiler/code_generator.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -571,8 +571,8 @@ def assignTarget(self, target, value):
         if isinstance(target, ast.Subscript):
             return self.visit_Subscript_Store(target, value)
         if isinstance(target, ast.Tuple):
-            for i, name in enumerate(target.elts):
-                self.set_value(self.visit(name), value.values[i])
+            for i, target in enumerate(target.elts):
+                self.assignTarget(target, value.values[i])
             return
         if isinstance(target, ast.Attribute):
             base = self.visit(target.value)
```
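The old tuple branch assumed every element of a tuple target was a plain name; recursing through `assignTarget` lets nested attribute (or subscript, or tuple) targets such as `p.first, p.second = swap(p)` work, which is what the new `test_assign_tuple_attrs` above exercises. The following is a stand-alone toy model of the dispatch, not the real `CodeGenerator` class:

```python
import ast

def assign_target(scope, target, value):
    """Assign `value` to an AST target, dispatching on the target's shape."""
    if isinstance(target, ast.Name):
        scope[target.id] = value
    elif isinstance(target, ast.Attribute):
        # Assumes the attribute base is a simple name, enough for the demo.
        setattr(scope[target.value.id], target.attr, value)
    elif isinstance(target, ast.Tuple):
        # Recurse so nested attribute/subscript/tuple targets are handled too,
        # mirroring the fix: each element goes back through assignTarget.
        for elt, item in zip(target.elts, value):
            assign_target(scope, elt, item)
    else:
        raise NotImplementedError(type(target))

class Pair:
    def __init__(self, first, second):
        self.first, self.second = first, second

scope = {"p": Pair(1, 2)}
stmt = ast.parse("p.first, p.second = (10, 20)").body[0]
assign_target(scope, stmt.targets[0], (10, 20))
assert (scope["p"].first, scope["p"].second) == (10, 20)
```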

python/triton/language/core.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -178,6 +178,9 @@ class constexpr_type(base_type):
     def __init__(self, value):
         self.value = value
 
+    def __eq__(self, other):
+        return self.value == other.value
+
     def __repr__(self) -> str:
         return f"constexpr[{self.value}]"
 
@@ -338,7 +341,7 @@ def constexpr_function(f):
     @wraps(f)
     def wrapper(*args, **kwargs):
         # de-constexpr arguments and discard the _builder keyword argument:
-        args = [getattr(x, "value", x) for x in args]
+        args = [_unwrap_if_constexpr(x) for x in args]
         kwargs = {k: getattr(v, "value", v) for (k, v) in kwargs.items() if k != "_builder"}
 
         # call the raw Python function f:
```
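Two small semantic changes here: `constexpr_type` values now compare equal when their wrapped values are equal, and `constexpr_function` arguments are unwrapped with `_unwrap_if_constexpr` rather than a bare `getattr(x, "value", x)`, which also "unwraps" any ordinary object that merely exposes a `.value` attribute. A minimal sketch of that difference, using simplified stand-ins for the classes and helper in `triton.language.core`:

```python
class constexpr:
    """Simplified stand-in for tl.constexpr."""
    def __init__(self, value):
        self.value = value

def _unwrap_if_constexpr(x):
    # Simplified: strip only the constexpr wrapper, leave everything else alone.
    return x.value if isinstance(x, constexpr) else x

class HasValue:
    """An ordinary argument that just happens to have a .value field."""
    def __init__(self):
        self.value = "inner"

arg = HasValue()
assert getattr(arg, "value", arg) == "inner"    # old expression: wrongly unwrapped
assert _unwrap_if_constexpr(arg) is arg         # new expression: left untouched
assert _unwrap_if_constexpr(constexpr(7)) == 7  # constexpr is still unwrapped
```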

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 2 additions & 1 deletion
```diff
@@ -271,7 +271,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
 
   // Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction
   // GFX950: rocdl.make.buffer.rsrc
-  // GFX950: rocdl.ds_bpermute
+  // Src ptrs are contiguous so we do expect to bypass the ds_bpermute (see lowering to LLVM)
+  // GFX950-NOT: rocdl.ds_bpermute
   // GFX950: rocdl.raw.ptr.buffer.load.lds
   // GFX950-NOT: rocdl.raw.ptr.buffer.load.lds
 
```

test/Conversion/cvt_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
```diff
@@ -127,7 +127,7 @@ tt.func private @convert_layout_blocked_blocked(%arg0: tensor<16x16xi32, #blocke
   // to this, we choose to fall back to the shared memory implementation.
 
   // CHECK-NOT: shfl.sync.idx
-  // CHECK: st.shared
+  // CHECK: store
 
   %0 = ttg.convert_layout %arg0 : tensor<16x16xi32, #blocked0> -> tensor<16x16xi32, #blocked1>
   tt.return %0 : tensor<16x16xi32, #blocked1>
```

0 commit comments
