Commit cb9a390

Merge OpenAI Triton commit 2d6fb76 (#4355)
This PR changes the Triton base from abd3bb0 to 2d6fb76 (May 22). Pass rate: 94.95% -> 94.63%. Please do not squash and merge this PR. A770 CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15325583238
2 parents 91066ce + 34f4c17 commit cb9a390

38 files changed: +1173 additions, −390 deletions


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 4 additions & 0 deletions

@@ -271,6 +271,10 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
 LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
                                            int numWarps);

+std::optional<LinearLayout>
+getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
+                             int numWarps);
+
 // Return a layout valid for TMemLoad op for a tmem layout of block MxN that
 // distributes the data along M for the warp groups. This doesn't affect the
 // TMem layout; it just returns a distributed layout compatible for tmem_load.

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "STORE_TMEM_TO_GLOBAL_BYPASS_SMEM",
     "ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
     "TRITON_F32_DEFAULT",
+    "TRITON_PREFER_TMEM_16x256_LAYOUT",
     "TRITON_INTEL_ADVANCED_PATH",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 86 additions & 0 deletions

@@ -1618,6 +1618,92 @@ LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
   return combineCtaCgaWithShape(regLanes, CTALayout, scaleType.getShape());
 }

+std::optional<LinearLayout>
+getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
+                             int numWarps) {
+  // Too small to distribute on two warp groups while using the 16x256 message.
+  if (numWarps == 8 && M == 64 && N <= 16 &&
+      oldType.getElementTypeBitWidth() < 32) {
+    return {};
+  }
+  assert(numWarps == 4 || numWarps == 8);
+  auto ctaLayout = getCTALayout(oldType.getEncoding());
+  SmallVector<int64_t> shape = getShapePerCTA(oldType);
+  MLIRContext *ctx = ctaLayout.getContext();
+
+  using basisT = std::vector<std::vector<int32_t>>;
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, 2);
+
+  unsigned numElementsPerThread = 256 / oldType.getElementTypeBitWidth();
+  int kWidth = 64 / oldType.getElementTypeBitWidth();
+  // Follow the layout given by a tmem load using this layout for the inner
+  // shape:
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
+  LinearLayout innerTile =
+      nvidiaMmaTile(ctx, {8, numElementsPerThread}, kWidth, {1, 0}, {0, 1});
+  innerTile =
+      innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]);
+  // Then distribute the rest along warpgroups and registers. The last warp
+  // dimension is distributed along M or N following the same order as in
+  // getTmemLoadStoreLayout32x32b. This allows us to use the same lowering to
+  // tmem for load and store. This part could be generalized by making the
+  // lowering of tmem load and store rely more on linear layout.
+  bool distributeMAlongWarps = false;
+  bool distributeNAlongWarps = false;
+  // Figure out how to distribute across warpgroups.
+  if (numWarps == 8) {
+    if (shape[0] > 128) {
+      distributeMAlongWarps = true;
+    } else {
+      distributeNAlongWarps = true;
+    }
+  }
+  int nBase = numElementsPerThread;
+  int maxRegN =
+      std::min(N, distributeNAlongWarps ? (int)shape[1] / 2 : (int)shape[1]);
+  if (maxRegN / nBase > 1) {
+    innerTile = innerTile * LinearLayout::identity1D(maxRegN / nBase, kRegister,
+                                                     outDimNames[1]);
+  }
+  if (M != 64) {
+    innerTile =
+        innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]);
+  }
+  // Distribute M along 4 warps to satisfy TMEM requirements.
+  innerTile = innerTile * LinearLayout::identity1D(4, kWarp, outDimNames[0]);
+
+  // Fill out the rest of the shape with M first then N.
+  int numMRegDim = std::min(128, (int)shape[0]) / M;
+  if (numMRegDim > 1) {
+    innerTile = innerTile *
+                LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]);
+  }
+  // Dim M=128 should be distributed on the second warp group.
+  int nextDim = 128;
+  if (distributeMAlongWarps) {
+    innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[0]);
+    nextDim <<= 1;
+  }
+  numMRegDim = shape[0] / nextDim;
+  if (numMRegDim > 1) {
+    innerTile = innerTile *
+                LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]);
+  }
+  int maxN = distributeNAlongWarps ? shape[1] / 2 : shape[1];
+  int numNRegDim = maxN / maxRegN;
+  if (numNRegDim > 1) {
+    innerTile = innerTile *
+                LinearLayout::identity1D(numNRegDim, kRegister, outDimNames[1]);
+  }
+  if (distributeNAlongWarps) {
+    innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[1]);
+  }
+  return combineCtaCgaWithShape(innerTile, ctaLayout, oldType.getShape());
+}
+
 LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
                                          int numWarps) {
   assert(numWarps == 8);
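
Editorial note: the two sizing quantities at the top of the new helper follow directly from the 16x256b message width. A small illustrative Python rendering of that arithmetic (an editorial sketch, not part of the PR and not the MLIR implementation):

    # Mirrors numElementsPerThread and kWidth from getTmemLoadStoreLayout16x256
    # above for a few common element widths.
    def tmem_16x256_tile_params(elem_bits: int) -> tuple[int, int]:
        num_elements_per_thread = 256 // elem_bits  # elements covered by each thread
        k_width = 64 // elem_bits                   # contiguous elements per 64 bits
        return num_elements_per_thread, k_width

    for bits in (32, 16, 8):
        print(bits, tmem_16x256_tile_params(bits))
    # prints: 32 (8, 2) / 16 (16, 4) / 8 (32, 8)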

lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp

Lines changed: 4 additions & 0 deletions

@@ -25,6 +25,10 @@ bool ttng::MMAv5PipelineableOperandsHelper::comesFromLoadOrOutsideLoop(
   while (isa<ttg::MemDescTransOp, ttg::MemDescReshapeOp>(v.getDefiningOp())) {
     v = v.getDefiningOp()->getOperand(0);
   }
+  if (auto tmemAlloc = dyn_cast<ttng::TMEMAllocOp>(v.getDefiningOp())) {
+    foundLoad = tmemAlloc;
+    return false;
+  }
   auto localAlloc = dyn_cast<ttg::LocalAllocOp>(v.getDefiningOp());
   if (!localAlloc) {
     return false;

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp

Lines changed: 1 addition & 1 deletion

@@ -257,7 +257,7 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
        llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps,
                  wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) {
     // "Guess" the register usage for each partition.
-    estRegs = tensorRegs ? 72 : 24;
+    estRegs = tensorRegs ? 88 : 24;

     // Layouts need to be reassigned if the number of warps changed and there
     // are tensor computations.

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 9 additions & 0 deletions

@@ -208,6 +208,15 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
     Operation *op = operandViews.pop_back_val();
     if (!op->hasOneUse() || !op->hasTrait<OpTrait::MemDescViewTrait>())
       continue;
+
+    // Duplicate the op if necessary to ensure the MMA op is the only user.
+    if (!llvm::all_of(op->getUsers(),
+                      [&](Operation *user) { return user == mmaOp; })) {
+      Operation *viewOp = OpBuilder(op).clone(*op);
+      mmaOp->replaceUsesOfWith(op->getResult(0), viewOp->getResult(0));
+      op = viewOp;
+    }
+
     schedule.trySchedule(mmaPartition, op);
     if (Operation *defOp = op->getOperand(0).getDefiningOp())
       operandViews.push_back(defOp);

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 32 additions & 5 deletions

@@ -23,13 +23,15 @@

 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Tools/Sys/GetEnv.hpp"

 #include <numeric>

 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"

@@ -96,8 +98,9 @@ TMemAllocation getTmemAllocSizes(MemDescType memDescType) {
   return TMemAllocation(numColumn, numRows);
 }

-Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
-                                  RankedTensorType oldType, unsigned numWarps) {
+Attribute getTmemLoadStoreLayout32x32b(unsigned M, unsigned N,
+                                       RankedTensorType oldType,
+                                       unsigned numWarps) {
   assert(numWarps == 4 || numWarps == 8);
   auto shape = getShapePerCTA(oldType);
   assert(shape.size() == 2);

@@ -146,6 +149,20 @@ Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
                          warpsPerCTA, order, ctaLayout);
 }

+Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
+                                  RankedTensorType oldType, unsigned numWarps) {
+  bool prefer16x256 =
+      triton::tools::getBoolEnv("TRITON_PREFER_TMEM_16x256_LAYOUT");
+  if (prefer16x256) {
+    std::optional<LinearLayout> ll =
+        getTmemLoadStoreLayout16x256(M, N, oldType, numWarps);
+    if (ll) {
+      return LinearEncodingAttr::get(oldType.getContext(), *ll);
+    }
+  }
+  return getTmemLoadStoreLayout32x32b(M, N, oldType, numWarps);
+}
+
 bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
                                             MemDescType memType, int numWarps) {
   auto tmemEnc = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(

@@ -159,6 +176,8 @@ bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
     return false;
   auto CTALayout = getCTALayout(tensorType.getEncoding());
   auto shapePerCTA = mlir::triton::gpu::getShapePerCTA(tensorType);
+  if (numWarps != 8)
+    return false;
   LinearLayout llLayout =
       getTmemLoadLayoutSplitLongM(M, N, tensorType, numWarps);
   return llEncoding.getLinearLayout() == llLayout;

@@ -170,7 +189,6 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
                                        MemDescType memType) {
   int numWarps = lookupNumWarps(op);
   assert(numWarps % 4 == 0);
-  int numWarpGroups = numWarps / 4;
   if (isa<triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
           memType.getEncoding())) {
     return tensorType.getEncoding() ==

@@ -184,8 +202,17 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
   int blockN = attr.getBlockN();
   if (isDistributedLayoutSplitMTmemLoadStore(tensorType, memType, numWarps))
     return true;
-  Attribute layout =
-      nvidia_gpu::getTmemCompatibleLayout(blockM, blockN, tensorType, numWarps);
+
+  auto ll16x256 =
+      getTmemLoadStoreLayout16x256(blockM, blockN, tensorType, numWarps);
+  if (ll16x256.has_value() &&
+      areLayoutsEquivalent(
+          tensorType.getShape(),
+          LinearEncodingAttr::get(tensorType.getContext(), ll16x256.value()),
+          tensorType.getEncoding()))
+    return true;
+  Attribute layout = nvidia_gpu::getTmemLoadStoreLayout32x32b(
+      blockM, blockN, tensorType, numWarps);
   // TODO: Add support for more layouts compatible with tmem load/store. There
   // will only be a discrete set of layouts possible due to the limitations of
   // tmem_load/store.
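
Editorial note: the dispatch added in getTmemCompatibleLayout prefers the 16x256 layout only when the env var is set and a layout can actually be built for the block shape; otherwise it falls back to the existing 32x32b layout. A toy, runnable Python rendering of that fallback order (editorial sketch; the helpers are placeholders standing in for the C++ functions, and only the early-out condition is taken from the code above):

    import os
    from typing import Optional

    def layout_16x256(M: int, N: int, elem_bits: int, num_warps: int) -> Optional[str]:
        # Early-out mirrored from getTmemLoadStoreLayout16x256: too small to
        # distribute on two warp groups while using the 16x256 message.
        if num_warps == 8 and M == 64 and N <= 16 and elem_bits < 32:
            return None
        return f"16x256 layout for {M}x{N}"

    def layout_32x32b(M: int, N: int, elem_bits: int, num_warps: int) -> str:
        return f"32x32b layout for {M}x{N}"

    def choose_tmem_layout(M: int, N: int, elem_bits: int, num_warps: int) -> str:
        if os.environ.get("TRITON_PREFER_TMEM_16x256_LAYOUT", "0") == "1":
            preferred = layout_16x256(M, N, elem_bits, num_warps)
            if preferred is not None:
                return preferred
        return layout_32x32b(M, N, elem_bits, num_warps)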

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@ files = [
     "python/triton/runtime/build.py",
     "python/triton/_utils.py",
     "python/test/unit/test_knobs.py",
+    "python/test/unit/runtime/test_build.py",
     "python/test/unit/runtime/test_compilation_listener.py",
 ]
 exclude = ["/build/"]

python/test/unit/language/test_core.py

Lines changed: 0 additions & 4 deletions

@@ -6017,10 +6017,6 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):
 def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
     if str(src_layout) == str(dst_layout):
         pytest.xfail("Do not convert same layout")
-    if (isinstance(src_layout, DotOperandLayout)
-            and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout)
-                                                             and isinstance(interm_layout, SharedLayout)):
-        pytest.xfail("DotOperandLayout <-> SharedLayout conversion is not completely supported")
     if is_hip() or is_xpu():
         try:
             scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N))

python/test/unit/language/test_matmul.py

Lines changed: 17 additions & 3 deletions

@@ -34,7 +34,7 @@ def matmul_kernel( #
         stride_cm, stride_cn,  #
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
         NUM_STAGES: tl.constexpr, SCALE_A: tl.constexpr = None, PRECISION: tl.constexpr = "ieee",
-        A_TRANS: tl.constexpr = False, EPILOGUE_SUBTILE: tl.constexpr = False):
+        A_TRANS: tl.constexpr = False, EPILOGUE_SUBTILE: tl.constexpr = False, dummy: tl.constexpr = 0):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_M)
     pid_m = pid % num_pid_m

@@ -97,8 +97,9 @@ def get_src_element_ty_size(dtype_str):
 @pytest.mark.parametrize("NUM_CTAS", [1, 2])
 @pytest.mark.parametrize("NUM_WARPS", [4, 8])
 @pytest.mark.parametrize("EPILOGUE_SUBTILE", [True, False])
+@pytest.mark.parametrize("LAYOUT_16x256", [True, False])
 def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS, device,
-                       EPILOGUE_SUBTILE):
+                       EPILOGUE_SUBTILE, LAYOUT_16x256, monkeypatch):
     if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9):
         pytest.xfail("Clusters requires nvidia compute capability >= 9")
     if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str)

@@ -118,6 +119,8 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
         pytest.skip("multi-CTAs is broken for mmav2")
     if EPILOGUE_SUBTILE and not is_xpu() and (is_hip() or NUM_CTAS > 1 or BLOCK_N >= 512):
         pytest.skip("creates convert layout too big to fit in smem")
+    if LAYOUT_16x256 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 10):
+        pytest.xfail("skip forcing tmem layout on non blackwell targets.")
     M, N, K = 1024, 512, 256
     torch.manual_seed(42)
     precision = "tf32" if dtype_src_str == "tensorfloat32" else "ieee"

@@ -133,12 +136,16 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
     b = torch.randn(K, N, dtype=dtype_src, device=device)
     A = a
     B = b
+    # pass a dummy constexpr argument to force recompilation.
+    if LAYOUT_16x256:
+        monkeypatch.setenv("TRITON_PREFER_TMEM_16x256_LAYOUT", "1")
     dtype_dst = getattr(torch, dtype_dst_str)
     output = torch.empty((M, N), dtype=dtype_dst, device=device)
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
     k = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0),
                             output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, PRECISION=precision,
-                            num_warps=NUM_WARPS, num_ctas=NUM_CTAS, EPILOGUE_SUBTILE=EPILOGUE_SUBTILE)
+                            num_warps=NUM_WARPS, num_ctas=NUM_CTAS, EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
+                            dummy=LAYOUT_16x256)
     ref_out = torch.matmul(A, B).to(torch.float32)
     output = output.to(torch.float32)
     if dtype_src_str == "float32":

@@ -161,6 +168,13 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
     ttgir = k.asm["ttgir"]
     count = ttgir.count("ttng.tc_gen5_mma")
     assert count == 2, "The TTGIR does not match the expected pattern."
+    ptx = k.asm["ptx"]
+    if LAYOUT_16x256:
+        assert "16x256b" in ptx, "PTX does not contain 16x256b"
+    else:
+        if "32x32b" not in ptx and "16x32b" not in ptx:
+            print(ptx)
+        assert ("32x32b" in ptx) or ("16x32b" in ptx), "PTX does not contain 32x32b or 16x32b"


 # persistent matmul with fused loops
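
Editorial note: the new assertions check which tcgen05 message shape actually reaches the PTX: with the environment variable set the kernel is expected to use the 16x256b form, otherwise the 32x32b or 16x32b forms. A condensed sketch of that check (the k.asm["ptx"] accessor is the one the test already uses; the helper name here is ours):

    def assert_expected_tmem_message(compiled_kernel, prefer_16x256: bool) -> None:
        # compiled_kernel is the handle returned by a Triton kernel launch,
        # e.g. k = matmul_kernel[grid](...) as in test_simple_matmul above.
        ptx = compiled_kernel.asm["ptx"]
        if prefer_16x256:
            assert "16x256b" in ptx, "PTX does not contain 16x256b"
        else:
            assert "32x32b" in ptx or "16x32b" in ptx, "PTX does not contain 32x32b or 16x32b"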
