Commit c49212a

Merge OpenAI Triton commit b220c76 (#4997)
This PR changes the Triton base from a9d79a6 to b220c76 (Aug 14). Pass rate: 98.85%.
2 parents 5c013a1 + e1b256f commit c49212a

File tree

25 files changed (+1041, -228 lines)

.github/workflows/integration-tests-amd.yml

Lines changed: 3 additions & 0 deletions
@@ -166,6 +166,9 @@ jobs:
           # Reenable test_functional_regression.py once it's fixed
           cd python/test/regression
           python3 -m pytest -s -n 8 ./test_cast_matmul.py
+      - name: Run microbenchmark tests
+        run: |
+          python3 python/test/microbenchmark/launch_overhead.py
       - name: Run Proton tests
         run: |
           unset HIP_VISIBLE_DEVICES

.github/workflows/integration-tests-nvidia.yml

Lines changed: 3 additions & 0 deletions
@@ -98,6 +98,9 @@ jobs:
         run: make test-interpret
       - name: Run regression tests
         run: make test-regression
+      - name: Run microbenchmark tests
+        # Microbenchmarks never fail, but running them gives us an easy way to track performance changes.
+        run: make test-microbenchmark
       - name: Run C++ unittests
         run: make test-cpp
       - name: Run Proton tests

Makefile

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@ test-gluon: all
 test-regression: all
 	$(PYTEST) -s -n $(NUM_PROCS) python/test/regression
 
+.PHONY: test-microbenchmark
+test-microbenchmark: all
+	$(PYTHON) python/test/microbenchmark/launch_overhead.py
+
 .PHONY: test-interpret
 test-interpret: all
 	cd python/test/unit && TRITON_INTERPRET=1 $(PYTEST) -s -n 16 -m interpreter cuda language/test_core.py language/test_standard.py \
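
The new test-microbenchmark target and the matching CI steps simply invoke python/test/microbenchmark/launch_overhead.py; the script itself is not part of this diff. As a rough, hypothetical sketch only (the kernel and function names below are made up, not the script's contents), a launch-overhead microbenchmark of this kind times many launches of a near-empty Triton kernel and reports the average per-launch cost:

# Hypothetical sketch of a launch-overhead measurement; not the actual
# launch_overhead.py shipped with Triton.
import time

import torch
import triton
import triton.language as tl


@triton.jit
def _noop_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Do almost no work so the measurement is dominated by launch cost.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(x_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


def measure_launch_overhead(iters: int = 1000) -> float:
    x = torch.zeros(1024, device="cuda")
    grid = (1,)
    # Warm up once so compilation is excluded from the timing.
    _noop_kernel[grid](x, x.numel(), BLOCK_SIZE=1024)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        _noop_kernel[grid](x, x.numel(), BLOCK_SIZE=1024)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e6  # microseconds per launch


if __name__ == "__main__":
    print(f"~{measure_launch_overhead():.1f} us per launch")

Because the numbers are informational rather than pass/fail, the CI step never gates the build; as the workflow comment says, it only leaves a record that makes launch-overhead regressions easy to spot.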

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ std::optional<int> maybeLookupNumWarps(Operation *op);
 // FIXME: Make this API and that of maybeLookupNumWarps consistent!
 // Utility to find the number of threads per warp
 int lookupThreadsPerWarp(OpBuilder &rewriter);
+int lookupNumCTAs(OpBuilder &rewriter);
 
 template <typename Key, typename Value> class Cache {
 public:

include/triton/Tools/LinearLayout.h

Lines changed: 4 additions & 0 deletions
@@ -432,6 +432,10 @@ class LinearLayout {
     return isSurjective() && getTotalInDimSize() == getTotalOutDimSize();
   }
 
+  // Remove a 1-sized dimension from the layout.
+  [[nodiscard]] LinearLayout unsqueezeIn(StringAttr dim) const;
+  [[nodiscard]] LinearLayout unsqueezeOut(StringAttr dim) const;
+
   const BasesT &getBases() const { return bases; }
 
   // Get the pos'th basis vector for the inDim -> outDim mapping.

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 10 additions & 3 deletions
@@ -3417,13 +3417,20 @@ int triton::gpu::lookupNumWarps(Operation *op) {
 
 int triton::gpu::lookupThreadsPerWarp(OpBuilder &rewriter) {
   assert(rewriter.getInsertionBlock() && "expected an insertion point");
-  Operation *op = rewriter.getInsertionBlock()->getParentOp();
-  while (op && !isa<ModuleOp>(op))
-    op = op->getParentOp();
+  Operation *op =
+      rewriter.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
   assert(op && "cannot create thread ID outside of module");
   return triton::gpu::TritonGPUDialect::getThreadsPerWarp(cast<ModuleOp>(op));
 }
 
+int triton::gpu::lookupNumCTAs(OpBuilder &rewriter) {
+  assert(rewriter.getInsertionBlock() && "expected an insertion point");
+  Operation *op =
+      rewriter.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
+  assert(op && "cannot create thread ID outside of module");
+  return triton::gpu::TritonGPUDialect::getNumCTAs(cast<ModuleOp>(op));
+}
+
 bool triton::gpu::areLayoutsEquivalent(ArrayRef<int64_t> shape,
                                        DistributedEncodingTrait lhs,
                                        DistributedEncodingTrait rhs) {

lib/Tools/LinearLayout.cpp

Lines changed: 22 additions & 0 deletions
@@ -1129,6 +1129,28 @@ LinearLayout LinearLayout::pseudoinvert() const {
   return identity.invertAndCompose(*this);
 }
 
+LinearLayout LinearLayout::unsqueezeIn(StringAttr dim) const {
+  assert(getInDimSize(dim) == 1);
+  SmallVector<std::pair<StringAttr, int32_t>> newInDims;
+  for (auto inDim : getInDimNames()) {
+    if (inDim != dim) {
+      newInDims.push_back({inDim, getInDimSize(inDim)});
+    }
+  }
+  return reshapeIns(newInDims);
+}
+
+LinearLayout LinearLayout::unsqueezeOut(StringAttr dim) const {
+  assert(getOutDimSize(dim) == 1);
+  SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
+  for (auto [outDim, outDimSize] : getOutDims()) {
+    if (outDim != dim) {
+      newOutDims.push_back({outDim, outDimSize});
+    }
+  }
+  return LinearLayout(bases, newOutDims, isSurjective());
+}
+
 llvm::MapVector<StringAttr, int32_t>
 LinearLayout::getFreeVariableMasks() const {
   std::unique_ptr<uint64_t[]> mat = getMatrix(*this);
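
For context on these new helpers: despite the "unsqueeze" names, the header comment and the code above show that both drop a trivial (size-1) dimension. A size-1 input dimension contributes no basis vectors, and a size-1 output dimension can only ever hold the value 0, so removing either does not change the mapping the layout computes. A toy Python model of that behavior (the dimension names and data representation here are invented for illustration; this is not the C++ LinearLayout API):

def unsqueeze_in(bases, out_sizes, dim):
    # A size-1 input dimension has no basis vectors, so dropping it is a no-op
    # on the mapping itself.
    assert len(bases[dim]) == 0, "input dim must have size 1 (no bases)"
    return {d: bs for d, bs in bases.items() if d != dim}, dict(out_sizes)


def unsqueeze_out(bases, out_sizes, dim):
    # A size-1 output dimension can only hold 0, so remove it from every basis
    # vector and from the output-size table.
    assert out_sizes[dim] == 1, "output dim must have size 1"
    new_bases = {d: [{o: v for o, v in b.items() if o != dim} for b in bs]
                 for d, bs in bases.items()}
    new_sizes = {o: s for o, s in out_sizes.items() if o != dim}
    return new_bases, new_sizes


# Example: drop a trivial "block" input dim and a trivial "cta" output dim.
bases = {"lane": [{"dim0": 1, "cta": 0}, {"dim0": 2, "cta": 0}], "block": []}
out_sizes = {"dim0": 4, "cta": 1}
bases, out_sizes = unsqueeze_in(bases, out_sizes, "block")
bases, out_sizes = unsqueeze_out(bases, out_sizes, "cta")
print(bases)      # {'lane': [{'dim0': 1}, {'dim0': 2}]}
print(out_sizes)  # {'dim0': 4}

In the real class, unsqueezeIn delegates to reshapeIns and unsqueezeOut rebuilds the layout from the surviving output dimensions, as the diff above shows.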

python/src/gluon_ir.cc

Lines changed: 5 additions & 6 deletions
@@ -217,8 +217,8 @@ py::object layoutToGluon(Attribute layout) {
 
     return layouts.AMDMFMALayout(
         amdMfma.getVersion(), instrShape, amdMfma.getIsTransposed(),
-        toStdVector(amdMfma.getWarpsPerCTA()),
-        toStdVector(amdMfma.getTilesPerWarp()), layouts.GluonDType(typeName),
+        toStdVector(amdMfma.getWarpsPerCTA()), layouts.GluonDType(typeName),
+        toStdVector(amdMfma.getTilesPerWarp()),
        toStdVector(ctaLayout.getCTAsPerCGA()),
        toStdVector(ctaLayout.getCTASplitNum()),
        toStdVector(ctaLayout.getCTAOrder()));
@@ -325,13 +325,12 @@ void init_gluon_ir(py::module &&m) {
       })
      .def("get_amd_mfma_layout",
           [](GluonOpBuilder &self, unsigned version,
+             std::vector<unsigned> &instrShape, bool transposed,
+             std::vector<unsigned> &warpsPerCta, mlir::Type elemType,
              std::vector<unsigned> &tilesPerWarp,
-             std::vector<unsigned> &warpsPerCta,
              std::vector<unsigned> &ctasPerCga,
              std::vector<unsigned> &ctaSplitNum,
-             std::vector<unsigned> &ctaOrder,
-             std::vector<unsigned> &instrShape, bool transposed,
-             mlir::Type elemType) -> Attribute {
+             std::vector<unsigned> &ctaOrder) -> Attribute {
            auto ctx = self.getContext();
            auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
                ctx, ctasPerCga, ctaSplitNum, ctaOrder);

python/test/gluon/test_core.py

Lines changed: 64 additions & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 import pytest
 
-from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hopper_or_newer, is_hopper
+from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hip_cdna3, is_hip_cdna4, is_hopper_or_newer, is_hopper
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
@@ -143,3 +143,66 @@ def test_warpgroup_mma(ASYNC):
     ref = torch.matmul(a, b)
 
     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)
+
+
+@pytest.mark.parametrize("M, N, K", [(32, 32, 16), (16, 16, 32)])
+@pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
+@pytest.mark.parametrize("num_warps", [4, 8])
+@pytest.mark.parametrize("cdna_version", [3, 4])
+def test_amd_mfma(M, N, K, in_dtype, num_warps, cdna_version):
+
+    @gluon.jit
+    def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak,  #
+               stride_bk, stride_bn,  #
+               stride_cm, stride_cn, BLOCK_SIZE_M: ttgl.constexpr, BLOCK_SIZE_N: ttgl.constexpr,
+               BLOCK_SIZE_K: ttgl.constexpr, blocked: ttgl.constexpr, mfma_layout: ttgl.constexpr):
+        dot_a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=8)
+        dot_b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=8)
+
+        offs_am = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked))
+        offs_bn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked))
+
+        offs_ak = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(0, blocked))
+        offs_bk = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(1, blocked))
+        offs_a = offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak
+        offs_b = offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+
+        a = ttgl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=offs_a)
+        b = ttgl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=offs_b)
+        a1 = ttgl.convert_layout(a, layout=dot_a_layout)
+        b1 = ttgl.convert_layout(b, layout=dot_b_layout)
+        acc = ttgl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], ttgl.float32, mfma_layout)
+        c = ttgl.amd.cdna3.mfma(a1, b1, acc)
+        c = ttgl.convert_layout(c, layout=blocked)
+        c = c.to(a_ptr.dtype.element_ty)
+
+        offs_cm = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked))
+        offs_cn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked))
+        offs_c = offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+        ttgl.amd.cdna3.buffer_store(stored_value=c, ptr=c_ptr, offsets=offs_c)
+
+    if not is_hip_cdna4() and not is_hip_cdna3():
+        pytest.skip("mfma requires target to be CDNA3 or CDNA4")
+
+    if is_hip_cdna3() and cdna_version != 3:
+        pytest.skip("On CDNA3 target, skip if mfma version is not 3")
+
+    if is_hip_cdna4() and cdna_version != 4:
+        pytest.skip("On CDNA4 target, skip if mfma version is not 4")
+
+    elem_type = torch.float16 if in_dtype == 'float16' else torch.bfloat16
+    a = torch.randn((M, K), device='cuda', dtype=elem_type) - 0.5
+    b = torch.randn((K, N), device='cuda', dtype=elem_type) - 0.5
+    c = torch.empty((M, N), device=a.device, dtype=elem_type)
+    nonkdim: ttgl.constexpr = 32
+    blocked: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[4, 4], threads_per_warp=[4, 16],
+                                                 warps_per_cta=[num_warps, 1], order=[1, 0])
+    mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=cdna_version, instr_shape=[nonkdim, nonkdim],
+                                                         transposed=True, warps_per_cta=[num_warps, 1])
+
+    kernel[1, 1](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=M,
+                 BLOCK_SIZE_N=N, BLOCK_SIZE_K=K, blocked=blocked, mfma_layout=mfma_layout, num_warps=num_warps)
+
+    ref = torch.matmul(a, b)
+    triton_output = c
+    torch.testing.assert_close(ref, triton_output)
