
Commit 6bf9a4b

Merge commit '818e892af90a0eb7fcd4d2fe29db908bf542c9ed'
2 parents 4094ea4 + 818e892 commit 6bf9a4b

22 files changed (+1079 -828 lines); 7 files shown below.

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 0 deletions
@@ -2354,6 +2354,18 @@ struct TritonGPUInferLayoutInterface
     return success();
   }
 
+  if (auto enc = dyn_cast<PaddedSharedEncodingAttr>(operandEncoding)) {
+    if (failed(checkRank(enc.getRank())))
+      return failure();
+
+    CTALayoutAttr ctaLayout =
+        permuteCTALayout(ctx, enc.getCTALayout(), order);
+    resultEncoding = PaddedSharedEncodingAttr::get(
+        ctx, enc.getIntervals(), enc.getPaddings(),
+        applyPermutation(invOrderUnsigned, enc.getOrder()), ctaLayout);
+    return success();
+  }
+
   auto ll = toLinearLayout(shape, operandEncoding);
   auto transposedLl = transposeLinearLayout(ll, order);
   resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl));
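
For readers following the order arithmetic: a minimal Python sketch of what the applyPermutation(invOrderUnsigned, enc.getOrder()) call above computes. permute_order and its argument names are illustrative only, not Triton helpers.

# Sketch: invert the transpose permutation, then map the source
# encoding's `order` through that inverse.
def permute_order(src_order, perm):
    inv = [0] * len(perm)  # inv[old_axis] = new_axis
    for new_axis, old_axis in enumerate(perm):
        inv[old_axis] = new_axis
    return [inv[axis] for axis in src_order]

# Matches the gluon test later in this commit:
# order [2, 0, 1] permuted by (1, 0, 2) becomes [2, 1, 0].
assert permute_order([2, 0, 1], [1, 0, 2]) == [2, 1, 0]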

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 2 additions & 1 deletion
@@ -42,8 +42,9 @@ void AutomaticWarpSpecialization::runOnOperation() {
   // pm.addPass(arith::createIntRangeOptimizationsPass());
   pm.addPass(createSCCPPass());
   pm.addPass(createCSEPass());
-  pm.addPass(createTritonGPUPartitionLoops());
+  pm.addPass(createNVWSAssignStagePhase());
   pm.addPass(createNVWSLowerAref());
+  pm.addPass(createTritonGPUPartitionLoops());
   pm.addPass(createNVWSLowerWarpGroup());
   if (failed(runPipeline(pm, getOperation())))
     return signalPassFailure();
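
Note: the pipeline now runs NVWSAssignStagePhase before NVWSLowerAref, and loop partitioning (TritonGPUPartitionLoops) moves after aref lowering instead of before it. This pairs with the simplification in LoadMMASpecialization.cpp below.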

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 1 addition & 10 deletions
@@ -811,16 +811,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
   Value lastIndex = loop.getResult(index.getArgNumber() - 1);
   Value lastPhase = loop.getResult(phase.getArgNumber() - 1);
   Value lastBar = createSingleBufferView(b, nodes.back().barNext, lastIndex);
-  auto waitBarrierOp = b.create<ttng::WaitBarrierOp>(lastBar, lastPhase);
-  auto node_front = nodes.front();
-  auto partition = schedule.getPartition(inBody(node_front.op));
-  PartitionBuilder b(waitBarrierOp->getLoc(), waitBarrierOp);
-  lastBar.getDefiningOp()->setAttr(kWarpSpecializeTagAttrName,
-                                   b.getI32IntegerAttr(schedule.getTag()));
-  waitBarrierOp->setAttr(kWarpSpecializeTagAttrName,
-                         b.getI32IntegerAttr(schedule.getTag()));
-  b.assignPartition(lastBar.getDefiningOp(), *partition);
-  b.assignPartition(waitBarrierOp, *partition);
+  b.create<ttng::WaitBarrierOp>(lastBar, lastPhase);
 }
 
 llvm::SetVector<Operation *> predOps;
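
The manual kWarpSpecializeTagAttrName and partition bookkeeping around the trailing WaitBarrierOp is dropped here, presumably because the NVWSAssignStagePhase pass added to the pipeline above now assigns stages and phases in one place.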

lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ std::pair<Value, AccessRange> findBufferAccess(Value a);
 
 std::pair<Value, AccessRange>
 findBufferAccessMemdescSubview(Operation *subview) {
-  OpBuilder builder(subview->getContext());
+  OpBuilder builder(subview);
   Location loc = subview->getLoc();
   TypedValue<ttg::MemDescType> src;
   SmallVector<int64_t> shape;
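
Constructing the OpBuilder from the op itself, rather than from its MLIRContext alone, also sets the builder's insertion point to just before subview, so ops created by this helper land at a concrete location.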

python/src/gluon_ir.cc

Lines changed: 28 additions & 1 deletion
@@ -97,6 +97,7 @@ struct GluonLayouts {
   py::handle NVMMASharedLayout;
   py::handle SwizzledSharedLayout;
   py::handle AMDMFMALayout;
+  py::handle PaddedSharedLayout;
   py::handle GluonDType;
 
   GluonLayouts() {
@@ -116,6 +117,8 @@ struct GluonLayouts {
     SwizzledSharedLayout =
         py::object(layouts.attr("SwizzledSharedLayout")).release();
     AMDMFMALayout = py::object(amdLayouts.attr("AMDMFMALayout")).release();
+    PaddedSharedLayout =
+        py::object(layouts.attr("PaddedSharedLayout")).release();
 
     auto core = py::module::import("triton.language.core");
     GluonDType = py::object(core.attr("dtype")).release();
@@ -199,7 +202,6 @@ py::object layoutToGluon(Attribute layout) {
   } else if (auto amdMfma = dyn_cast<ttg::AMDMfmaEncodingAttr>(layout)) {
     auto ctaLayout = amdMfma.getCTALayout();
     std::vector<unsigned> instrShape{amdMfma.getMDim(), amdMfma.getNDim()};
-
     auto elemTypeOpt = amdMfma.getElementType();
     const char *typeName = "fp32";
     if (elemTypeOpt.has_value()) {
@@ -222,6 +224,19 @@ py::object layoutToGluon(Attribute layout) {
                                 toStdVector(ctaLayout.getCTAsPerCGA()),
                                 toStdVector(ctaLayout.getCTASplitNum()),
                                 toStdVector(ctaLayout.getCTAOrder()));
+  } else if (auto paddedShared =
+                 dyn_cast<ttg::PaddedSharedEncodingAttr>(layout)) {
+    auto ctaLayout = paddedShared.getCTALayout();
+    std::vector<std::pair<unsigned, unsigned>> intervalPaddingPairs;
+    for (auto [interval, padding] :
+         llvm::zip(paddedShared.getIntervals(), paddedShared.getPaddings())) {
+      intervalPaddingPairs.push_back({interval, padding});
+    }
+    return layouts.PaddedSharedLayout(intervalPaddingPairs,
+                                      toStdVector(paddedShared.getOrder()),
+                                      toStdVector(ctaLayout.getCTAsPerCGA()),
+                                      toStdVector(ctaLayout.getCTASplitNum()),
+                                      toStdVector(ctaLayout.getCTAOrder()));
   }
 
   throw py::value_error("Unhandled encoding encountered");
@@ -338,6 +353,18 @@ void init_gluon_ir(py::module &&m) {
                 ctx, version, warpsPerCta, tilesPerWarp, instrShape[0],
                 instrShape[1], transposed, ctaLayout, elemType);
           })
+      .def("get_padded_shared_layout",
+           [](GluonOpBuilder &self, std::vector<unsigned> &intervals,
+              std::vector<unsigned> &paddings, std::vector<unsigned> &order,
+              std::vector<unsigned> &ctasPerCga,
+              std::vector<unsigned> &ctaSplitNum,
+              std::vector<unsigned> &ctaOrder) -> Attribute {
+             auto ctx = self.getContext();
+             auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
+                 ctx, ctasPerCga, ctaSplitNum, ctaOrder);
+             return ttg::PaddedSharedEncodingAttr::get(ctx, intervals, paddings,
+                                                       order, ctaLayout);
+           })
       .def("get_nvmma_shared_layout",
           [](GluonOpBuilder &self, unsigned swizzleByteWidth,
              unsigned elementBitwidth, bool transposed, bool fp4Padded,
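
To make the interval/padding pairs concrete: as I read #ttg.padded_shared, each pair [interval:+pad] inserts pad elements of padding after every interval elements, and the pairs accumulate. A small sketch of that rule (padded_offset is a hypothetical helper, not a Triton API):

# Hypothetical helper illustrating the padding rule encoded by
# PaddedSharedEncodingAttr: every `interval` elements contribute
# `padding` extra elements, accumulated across all pairs.
def padded_offset(linear_idx, interval_padding_pairs):
    offset = linear_idx
    for interval, padding in interval_padding_pairs:
        offset += (linear_idx // interval) * padding
    return offset

# For #ttg.padded_shared<[2:+1, 4:+2, 8:+4]>, element 8 lands at
# 8 + (8 // 2) * 1 + (8 // 4) * 2 + (8 // 8) * 4 = 20.
assert padded_offset(8, [(2, 1), (4, 2), (8, 4)]) == 20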

python/test/gluon/test_frontend.py

Lines changed: 91 additions & 10 deletions
@@ -940,10 +940,14 @@ def libdevice_kernel():
     a = ttgl.full([4, 32], 1, ttgl.float32, layout)
     b = ttgl.full([4, 32], 2, ttgl.float32, layout)
     c = ttgl.full([4, 32], 4, ttgl.float32, layout)
+
     libdevice.abs(a)
     libdevice.fast_dividef(a, b)
     libdevice.fma(a, b, c)
 
+    libdevice.isnan(a)
+    libdevice.isinf(a)
+
 
 @pytest.mark.parametrize("target", ALL_TARGETS)
 def test_libdevice(target):
@@ -962,6 +966,14 @@ def test_libdevice(target):
     %0 = tt.extern_elementwise %cst_0 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>) -> tensor<4x32xf32, #blocked>
     %1 = tt.extern_elementwise %cst_0, %cst_2 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>, tensor<4x32xf32, #blocked>) -> tensor<4x32xf32, #blocked>
     %2 = tt.extern_elementwise %cst_0, %cst_2, %cst_4 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>, tensor<4x32xf32, #blocked>, tensor<4x32xf32, #blocked>) -> tensor<4x32xf32, #blocked>
+    %3 = tt.extern_elementwise %cst_0 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>) -> tensor<4x32xi32, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %cst_5 = arith.constant dense<0> : tensor<4x32xi32, #blocked>
+    %4 = arith.cmpi ne, %3, %cst_5 : tensor<4x32xi32, #blocked>
+    %5 = tt.extern_elementwise %cst_0 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>) -> tensor<4x32xi32, #blocked>
+    %c0_i32_6 = arith.constant 0 : i32
+    %cst_7 = arith.constant dense<0> : tensor<4x32xi32, #blocked>
+    %6 = arith.cmpi ne, %5, %cst_7 : tensor<4x32xi32, #blocked>
     tt.return
   }
 }
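
As the expected IR above shows, libdevice.isnan and libdevice.isinf lower to tt.extern_elementwise calls returning an i32 tensor, which the frontend then compares against zero (arith.cmpi ne) to obtain the boolean result.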
@@ -1926,10 +1938,10 @@ def buffer_load_store_kernel(x, y):
     mask = ttgl.full((64, 64), 1, tl.int1, layout=layout)
     other = ttgl.full((64, 64), 1.0, tl.float32, layout=layout)
     a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
-    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs')
 
     a = ttgl.amd.cdna4.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
-    ttgl.amd.cdna4.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+    ttgl.amd.cdna4.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs')
 
 
 @pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
@@ -1951,9 +1963,9 @@ def test_buffer_load_store(target):
     %cst_0 = arith.constant 1.000000e+00 : f32
     %cst_1 = arith.constant dense<1.000000e+00> : tensor<64x64xf32, #blocked>
     %3 = amdgpu.buffer_load %arg0[%2], %cst, %cst_1 cacheModifier = ca : tensor<64x64xf32, #blocked>
-    amdgpu.buffer_store %3, %arg1[%2], %cst cacheModifier = ca : tensor<64x64xf32, #blocked>
+    amdgpu.buffer_store %3, %arg1[%2], %cst cacheModifier = cs : tensor<64x64xf32, #blocked>
     %4 = amdgpu.buffer_load %arg0[%2], %cst, %cst_1 cacheModifier = ca : tensor<64x64xf32, #blocked>
-    amdgpu.buffer_store %4, %arg1[%2], %cst cacheModifier = ca : tensor<64x64xf32, #blocked>
+    amdgpu.buffer_store %4, %arg1[%2], %cst cacheModifier = cs : tensor<64x64xf32, #blocked>
     tt.return
   }
 }
@@ -1971,15 +1983,15 @@ def buffer_load_store_with_broadcast_kernel(x, y):
 
     mask = ttgl.full((64, 1), 1, tl.int1, layout=layout)
     a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
-    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs')
 
     mask = ttgl.full((1, 64), 1, tl.int1, layout=layout)
     a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
-    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs')
 
     other = 1.0
     a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
-    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs')
 
 
 @pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
@@ -2003,19 +2015,19 @@ def test_buffer_load_store_with_broadcast(target):
     %3 = tt.broadcast %cst_1 : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked>
     %4 = amdgpu.buffer_load %arg0[%2], %3, %cst_0 cacheModifier = ca : tensor<64x64xf32, #blocked>
     %5 = tt.broadcast %cst_1 : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked>
-    amdgpu.buffer_store %4, %arg1[%2], %5 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    amdgpu.buffer_store %4, %arg1[%2], %5 cacheModifier = cs : tensor<64x64xf32, #blocked>
     %true_2 = arith.constant true
     %cst_3 = arith.constant dense<true> : tensor<1x64xi1, #blocked>
     %6 = tt.broadcast %cst_3 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked>
     %7 = amdgpu.buffer_load %arg0[%2], %6, %cst_0 cacheModifier = ca : tensor<64x64xf32, #blocked>
     %8 = tt.broadcast %cst_3 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked>
-    amdgpu.buffer_store %7, %arg1[%2], %8 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    amdgpu.buffer_store %7, %arg1[%2], %8 cacheModifier = cs : tensor<64x64xf32, #blocked>
     %cst_4 = arith.constant 1.000000e+00 : f32
     %9 = tt.broadcast %cst_3 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked>
     %cst_5 = arith.constant dense<1.000000e+00> : tensor<64x64xf32, #blocked>
     %10 = amdgpu.buffer_load %arg0[%2], %9, %cst_5 cacheModifier = ca : tensor<64x64xf32, #blocked>
     %11 = tt.broadcast %cst_3 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked>
-    amdgpu.buffer_store %10, %arg1[%2], %11 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    amdgpu.buffer_store %10, %arg1[%2], %11 cacheModifier = cs : tensor<64x64xf32, #blocked>
     tt.return
   }
 }
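
In these tests the stores switch from the '.ca' cache modifier to '.cs' while the loads keep '.ca', presumably so that load and store modifiers are distinct and a modifier applied to the wrong op would now be caught.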
@@ -2111,3 +2123,72 @@ def kernel():
   }
 }
 """)
+
+
+@gluon.jit
+def padded_shared_layout_kernel():
+    padded_shared_layout: ttgl.constexpr = ttgl.PaddedSharedLayout(interval_padding_pairs=[[2, 1], [4, 2], [8, 4]],
+                                                                   order=[1, 0], ctas_per_cga=[1, 1],
+                                                                   cta_split_num=[1, 1], cta_order=[1, 0])
+
+    ttgl.allocate_shared_memory(ttgl.int32, [64, 64], padded_shared_layout)
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_padded_shared_layout(target):
+    # Tests the construction of PaddedSharedEncodingAttr from Gluon.
+    module = run_parser(padded_shared_layout_kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#shared = #ttg.padded_shared<[2:+1, 4:+2, 8:+4] {order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @padded_shared_layout_kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<64x64xi32, #shared, #smem, mutable>
+    tt.return
+  }
+}
+""")
+
+
+@gluon.jit
+def infer_layout_for_padded_shared_kernel():
+    layout: ttgl.constexpr = ttgl.PaddedSharedLayout(interval_padding_pairs=[[2, 1], [4, 2], [8, 4]], order=[2, 0, 1])
+    smem = ttgl.allocate_shared_memory(ttgl.int32, [32, 4, 32], layout)
+
+    reshaped = smem.permute((1, 0, 2))
+    """
+    The permutation is [1, 0, 2], which maps:
+      old axis 1 to new axis 0
+      old axis 0 to new axis 1
+      old axis 2 to new axis 2
+    so inverseMapping[0] = 1, inverseMapping[1] = 0, inverseMapping[2] = 2.
+
+    The order in srcEnc is [2, 0, 1],
+    thus the order in dstEnc is:
+      newOrder[0] = inverseMapping[srcEncOrder[0]] = 2
+      newOrder[1] = inverseMapping[srcEncOrder[1]] = 1
+      newOrder[2] = inverseMapping[srcEncOrder[2]] = 0
+    """
+    ttgl.static_assert(
+        reshaped.layout == ttgl.PaddedSharedLayout(interval_padding_pairs=[(2, 1), (4, 2), (8, 4)], order=[2, 1, 0]))
+
+
+@pytest.mark.parametrize("target", ALL_TARGETS)
+def test_infer_layout_for_padded_shared(target):
+    # Tests the conversion from PaddedSharedEncodingAttr back to the Gluon
+    # PaddedSharedLayout object; the conversion lives in layoutToGluon, which
+    # ttgl.permute ultimately relies on.
+    module = run_parser(infer_layout_for_padded_shared_kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#shared = #ttg.padded_shared<[2:+1, 4:+2, 8:+4] {order = [2, 0, 1]}>
+#shared1 = #ttg.padded_shared<[2:+1, 4:+2, 8:+4] {order = [2, 1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @infer_layout_for_padded_shared_kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<32x4x32xi32, #shared, #smem, mutable>
+    %1 = ttg.memdesc_trans %0 {order = array<i32: 1, 0, 2>} : !ttg.memdesc<32x4x32xi32, #shared, #smem, mutable> -> !ttg.memdesc<4x32x32xi32, #shared1, #smem, mutable>
+    tt.return
+  }
+}
+""")

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 27 additions & 0 deletions
@@ -1685,3 +1685,30 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B
     # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
     assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[
         "ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]
+
+
+@pytest.mark.interpreter
+@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
+def test_tensor_descriptor_store_downcast(dtype_str, device):
+
+    @triton.jit
+    def kernel(desc, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
+        moffset = tl.program_id(axis=0) * M_BLOCK
+        noffset = tl.program_id(axis=1) * N_BLOCK
+        midx = moffset + tl.arange(0, M_BLOCK)[:, None]
+        nidx = noffset + tl.arange(0, N_BLOCK)[None, :]
+        val_f32 = (midx * N + nidx).to(tl.float32)
+        # implicit downcast in the store.
+        desc.store([moffset, noffset], val_f32)
+
+    M, N = 32, 128
+    torch_dtype = getattr(torch, dtype_str)
+    M_BLOCK = 8
+    N_BLOCK = 32
+    grid_m = M // M_BLOCK
+    grid_n = N // N_BLOCK
+    out = torch.empty((M, N), dtype=torch_dtype, device=device)
+    desc = TensorDescriptor(out, out.shape, out.stride(), [M_BLOCK, N_BLOCK])
+    kernel[(grid_m, grid_n)](desc, M, N, M_BLOCK=M_BLOCK, N_BLOCK=N_BLOCK)
+    ref = torch.arange(M * N, dtype=torch.float32, device=device).reshape(M, N).to(torch_dtype)
+    torch.testing.assert_close(out, ref)
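
The new test covers the implicit downcast performed by desc.store: the kernel builds a float32 block while the descriptor's element type is float16 or bfloat16, and the torch reference applies the matching conversion with .to(torch_dtype).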
