
Commit 22b5b2c

Merge OpenAI Triton commit b116579 (#5575)
This PR changes the Triton base from fa5f79a to b116579 (Nov 17). Pass rate: 95.41% -> 95.42%.
2 parents 3fc0945 + 06517f5 · commit 22b5b2c

File tree

33 files changed: +2646 / -180 lines

.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ jobs:
   integration-tests-amd:
     runs-on: ${{ matrix.runner }}
     timeout-minutes: 45
-    continue-on-error: ${{ matrix.runner[1] == 'gfx90a' }}
+    continue-on-error: ${{ matrix.runner[1] == 'gfx90a' || matrix.runner[0] == 'gfx950' }}
     strategy:
       matrix:
         runner: ${{ fromJson(inputs.matrix) }}

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
     auto funcTy = funcOp.getFunctionType();
     auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
     bool isKernel = triton::isKernel(funcOp);
-    if (isKernel) {
+    if (isKernel && targetInfo.isCuda()) {
       for (auto i : llvm::seq(amendedInputTy.size())) {
         if (isa<TensorDescType>(amendedInputTy[i])) {
           funcOp.setArgAttr(i, "tt.nv_tma_desc",

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 2 deletions
@@ -887,9 +887,10 @@ LogicalResult getConvertBackwardSlice(
     queue.pop_back();
     if (!isa<RankedTensorType>(currentValue.getType()))
       continue;
-    // Skip propagating through for op results for now.
+    // Skip propagating through for op/while op results for now.
     // TODO: enable this based on needs.
-    if (currentValue.getDefiningOp<scf::ForOp>())
+    if (currentValue.getDefiningOp<scf::ForOp>() ||
+        currentValue.getDefiningOp<scf::WhileOp>())
       return failure();
     if (failed(updateLayout(currentValue, encoding)))
       return failure();

python/src/ir.cc

Lines changed: 35 additions & 12 deletions
@@ -212,19 +212,42 @@ py::list getTensorDescMetadata(ModuleOp &mod) {
 
     auto blockType = descTy.getBlockType();
     auto encoding = blockType.getEncoding();
-    auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
-    auto swizzle = ttng::getTMASwizzleMode(nullptr, descTy);
-    auto elemType = ttng::getTMAElementType(nullptr, descTy);
-    assert(swizzle.has_value());
-    assert(elemType.has_value());
-    auto blockSize = ttng::getTMABlockShape(blockType, /*packedSize=*/false);
+
     py::dict metadata;
-    metadata["swizzle"] = *swizzle;
-    metadata["elem_size"] = descTy.getBlockType().getElementTypeBitWidth() / 8;
-    metadata["elem_type"] = *elemType;
-    metadata["block_size"] =
-        std::vector<int>(blockSize.begin(), blockSize.end());
-    metadata["fp4_padded"] = mmaEncoding && mmaEncoding.getFp4Padded();
+    if (isa<ttg::NVMMASharedEncodingAttr>(encoding)) {
+      auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
+      auto swizzle = ttng::getTMASwizzleMode(nullptr, descTy);
+      auto elemType = ttng::getTMAElementType(nullptr, descTy);
+      assert(swizzle.has_value());
+      assert(elemType.has_value());
+      auto blockSize = ttng::getTMABlockShape(blockType, /*packedSize=*/false);
+      metadata["swizzle"] = *swizzle;
+      metadata["elem_size"] =
+          descTy.getBlockType().getElementTypeBitWidth() / 8;
+      metadata["elem_type"] = *elemType;
+      metadata["block_size"] =
+          std::vector<int>(blockSize.begin(), blockSize.end());
+      metadata["fp4_padded"] = mmaEncoding && mmaEncoding.getFp4Padded();
+    } else {
+      auto blockShape = blockType.getShape();
+      metadata["block_size"] =
+          std::vector<int>(blockShape.begin(), blockShape.end());
+      metadata["elem_bits"] = blockType.getElementTypeBitWidth();
+
+      if (auto paddedEnc = dyn_cast<ttg::PaddedSharedEncodingAttr>(encoding)) {
+        py::list intervalPaddingPairs;
+        for (auto [interval, padding] : llvm::zip_equal(
+                 paddedEnc.getIntervals(), paddedEnc.getPaddings())) {
+          py::list pair;
+          pair.append(interval);
+          pair.append(padding);
+          intervalPaddingPairs.append(pair);
+        }
+        metadata["interval_padding_pairs"] = intervalPaddingPairs;
+
+        auto blockShape = blockType.getShape();
+      }
+    }
     result.append(std::move(metadata));
   }
   return result;
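For orientation, here is a rough Python sketch of the two metadata shapes the updated getTensorDescMetadata can now emit. Only the dictionary keys come from the diff above; the values are made-up examples.

# Hypothetical example values; only the keys are taken from the ir.cc diff above.
nvmma_descriptor_metadata = {      # NVMMASharedEncodingAttr (TMA) path
    "swizzle": 3,                  # *swizzle
    "elem_size": 2,                # element size in bytes
    "elem_type": 6,                # *elemType
    "block_size": [64, 64],
    "fp4_padded": False,
}
other_descriptor_metadata = {      # non-NVMMA path, e.g. a padded shared encoding
    "block_size": [16, 64],
    "elem_bits": 16,
    # emitted only when the encoding is a PaddedSharedEncodingAttr:
    "interval_padding_pairs": [[32, 4]],
}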

python/src/specialize.cc

Lines changed: 18 additions & 6 deletions
@@ -39,7 +39,8 @@ static bool init_called = false;
 static PyObject *constexpr_cls = nullptr;
 static PyObject *jit_callable_cls = nullptr;
 static PyObject *tensor_descriptor_cls = nullptr;
-static PyObject *gluon_tensor_descriptor_cls = nullptr;
+static PyObject *nvidia_tensor_descriptor_cls = nullptr;
+static PyObject *amd_tensor_descriptor_cls = nullptr;
 static PyObject *canonicalize_dtype_fn = nullptr;
 static PyObject *canonicalize_ptr_dtype_fn = nullptr;
 static PyObject *torch_tensor_cls = nullptr;
@@ -123,8 +124,10 @@ bool init_globals() noexcept try {
   jit_callable_cls = import_from("triton.runtime.jit", "JITCallable");
   tensor_descriptor_cls =
       import_from("triton.tools.tensor_descriptor", "TensorDescriptor");
-  gluon_tensor_descriptor_cls = import_from(
+  nvidia_tensor_descriptor_cls = import_from(
       "triton.experimental.gluon.nvidia.hopper", "TensorDescriptor");
+  amd_tensor_descriptor_cls =
+      import_from("triton.experimental.gluon.amd.gfx1250", "TensorDescriptor");
 
   auto m_canonicalize = py::module_::import("triton._utils");
   canonicalize_dtype_fn = import_from("triton._utils", "canonicalize_dtype");
@@ -442,9 +445,13 @@ void init_type_handler_cache() {
         handle_tensor_descriptor;
   }
   // GluonTensorDescriptor
-  if (gluon_tensor_descriptor_cls &&
-      PyType_Check(gluon_tensor_descriptor_cls)) {
-    type_handler_cache[(PyTypeObject *)gluon_tensor_descriptor_cls] =
+  if (nvidia_tensor_descriptor_cls &&
+      PyType_Check(nvidia_tensor_descriptor_cls)) {
+    type_handler_cache[(PyTypeObject *)nvidia_tensor_descriptor_cls] =
+        handle_gluon_tensor_descriptor;
+  }
+  if (amd_tensor_descriptor_cls && PyType_Check(amd_tensor_descriptor_cls)) {
+    type_handler_cache[(PyTypeObject *)amd_tensor_descriptor_cls] =
         handle_gluon_tensor_descriptor;
   }
   // constexpr
@@ -491,7 +498,12 @@ std::pair<py::object, py::object> specialize_arg(PyObject *backend,
                                      align);
   }
 
-  if (PyObject_IsInstance(arg, gluon_tensor_descriptor_cls)) {
+  if (PyObject_IsInstance(arg, nvidia_tensor_descriptor_cls)) {
+    return handle_gluon_tensor_descriptor(backend, arg, is_const,
+                                          specialize_value, align);
+  }
+
+  if (PyObject_IsInstance(arg, amd_tensor_descriptor_cls)) {
     return handle_gluon_tensor_descriptor(backend, arg, is_const,
                                           specialize_value, align);
   }
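A minimal usage sketch of the two descriptor classes the specializer now recognizes. The import paths come from the init_globals() change above; everything else is illustrative. Both classes are dispatched to the same handle_gluon_tensor_descriptor handler, so either can be passed directly as a @gluon.jit kernel argument.

# Import paths as registered in init_globals() above; the AMD path assumes a
# Triton build that ships the gfx1250 Gluon module.
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as NvTensorDescriptor
from triton.experimental.gluon.amd.gfx1250 import TensorDescriptor as AmdTensorDescriptor

# Either descriptor type goes through the same specialization path when used
# as a kernel argument, e.g. kernel[grid](desc, ...).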

python/test/gluon/test_core.py

Lines changed: 2 additions & 2 deletions
@@ -1583,7 +1583,7 @@ def kernel(out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, a,
     assert "ttng.tc_gen5_mma_scaled" in ttgir
 
 
-@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere or newer")
+@pytest.mark.xfail(not is_ampere_or_newer(), reason="Requires Ampere or newer", run=False)
 def test_coalesced_layout():
 
     @gluon.jit
@@ -1628,7 +1628,7 @@ def kernel(in_ptr, out_ptr, #
     torch.testing.assert_close(output, ref)
 
 
-@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere or newer")
+@pytest.mark.xfail(not is_ampere_or_newer(), reason="Requires Ampere or newer", run=False)
 def test_convert_auto_layout_to_coalesced_layout():
 
     @gluon.jit
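The marker change above swaps skipif for xfail with run=False: with skipif the tests are reported as skipped, while xfail(..., run=False) reports them as expected failures without executing the body at all. A small illustrative snippet (not part of the diff):

import pytest

# Illustrative only: when the condition is true, pytest reports this test as
# XFAIL and never executes the body because run=False.
@pytest.mark.xfail(True, reason="Requires Ampere or newer", run=False)
def test_never_executed_on_old_hardware():
    raise RuntimeError("unreachable while the xfail condition holds")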

python/test/gluon/test_frontend.py

Lines changed: 36 additions & 0 deletions
@@ -3140,6 +3140,42 @@ def test_amd_tdm_load(target):
     """)
 
 
+@gluon.jit
+def amd_host_tdm_load_kernel(desc):
+    buffer = ttgl.allocate_shared_memory(desc.dtype, shape=desc.block_shape, layout=desc.layout)
+    ttgl.amd.gfx1250.tdm.async_load(desc, offsets=[0, 2], dest=buffer)
+
+    ttgl.amd.gfx1250.tdm.async_wait(0)
+    buffer.load(layout=ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]))
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
+def test_amd_host_tdm_load(target):
+
+    ptr = MockTensor(ttgl.float16, shape=(32, 128))
+    layout = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [16, 64], [1, 0])
+    desc = gluon.amd.gfx1250.TensorDescriptor.from_tensor(ptr, block_shape=(16, 64), layout=layout)
+    module = run_parser(amd_host_tdm_load_kernel, *make_args(desc), target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @amd_host_tdm_load_kernel(%arg0: !tt.tensordesc<tensor<16x64xf16, #shared>>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %true = arith.constant true
+    %1 = amdg.async_tdm_copy_global_to_local %arg0[%c0_i32, %c2_i32] into %0, %true : !tt.tensordesc<tensor<16x64xf16, #shared>> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
+    %2 = amdg.async_tdm_wait {num = 0 : i32}
+    %3 = ttg.local_load %0 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<16x64xf16, #blocked>
+    tt.return
+  }
+}
+""")
+
+
 @gluon.jit
 def amd_tdm_store_kernel(ptr):
     SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])

python/test/unit/tools/test_aot.py

Lines changed: 32 additions & 0 deletions
@@ -66,6 +66,26 @@ def kernel(C, A, B, M, N, K,
     tl.store(c_ptrs, c)
 """
 
+gluon_kernel_src = """
+from triton.experimental import gluon
+from triton.experimental.gluon import language as gl
+
+@gluon.jit
+def kernel(
+    C, A, B, M, N, K,
+    stride_cm, stride_cn,
+    stride_am, stride_ak,
+    stride_bk, stride_bn,
+    BLOCK_M: gl.constexpr,
+    BLOCK_N: gl.constexpr,
+    BLOCK_K: gl.constexpr
+):
+    layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64], warps_per_cta=[1], order=[0])
+    offs = gl.arange(0, 64, layout=layout)
+    a = gl.load(A + offs)
+    gl.store(B + offs, a)
+"""
+
 test_utils_src = """
 #include <cuda.h>
 #include <stdio.h>
@@ -672,3 +692,15 @@ def test_ttgir_to_spv():
     assert "OpCapability Kernel" in spv
     assert "LocalSize 128 1 1" in spv
     assert "SubgroupSize 32" in spv
+
+
+def test_gluon_kernel():
+    if not is_hip():
+        pytest.xfail("Gluon kernel is only supported on HIP")
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        dtype = "fp16"
+        BM, BN, BK = 16, 16, 16
+
+        kernel_path = write_triton_kernels(tmp_dir, gluon_kernel_src, kernel_utils_src)
+        compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK)
+        check_hasco_binary_str(tmp_dir, dtype)
Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 from . import nvidia
+from . import amd
 from ._runtime import constexpr_function, jit
 from triton.language.core import must_use_result
 
-__all__ = ["constexpr_function", "jit", "must_use_result", "nvidia"]
+__all__ = ["constexpr_function", "jit", "must_use_result", "nvidia", "amd"]
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from . import gfx1250
+
+__all__ = ["gfx1250"]
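Taken together, the package re-export changes above make the new AMD namespace reachable from the top-level Gluon package. A minimal sketch, assuming a Triton build that includes these modules:

# Minimal sketch: the re-exports above expose gluon.amd alongside gluon.nvidia.
from triton.experimental import gluon

desc_cls = gluon.amd.gfx1250.TensorDescriptor  # same class used in test_frontend.py above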
