
Commit 609e327

Merge commit '5d84a9122b519251d1453fc7e7f31e2e304dc1d6'
2 parents: 633d32d + 5d84a91

25 files changed: +300 −383 lines


CMakeLists.txt

Lines changed: 0 additions & 4 deletions
@@ -89,10 +89,6 @@ if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE "Release")
 endif()
 
-if(NOT WIN32)
-  find_library(TERMINFO_LIBRARY tinfo)
-endif()
-
 if(TRITON_BUILD_UT)
   # This is an aggregate target for all unit tests.
   add_custom_target(TritonUnitTests)

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 26 deletions
@@ -528,32 +528,6 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
                   triton::gpu::PaddedSharedEncodingAttr layout,
                   unsigned bitwidth, Value smemOffset, bool offsetInBytes);
 
-// Emits IR to load data from shared memory into registers, or to store data
-// from registers into shared memory.
-//
-// You supply perVectorCallback, which is called once per group of register
-// elements to transfer. You can use this callback to emit IR to load or store
-// data from or to shared memory.
-//
-// elemLlvmTy should be dstTy's element type converted to an LLVM-dialect type.
-//
-// If maxVecElems is provided, we won't vectorize more than this many elements.
-//
-// Returns true on success.
-[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
-    RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
-    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
-
-[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
-    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    Value laneId, Value warpId,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
-
 // Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
 // We might want to merge them at some point, but having to support
 // ldmatrix.trans makes the code in lowerLdStMatrix a bit specific
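
Editorial note: per the removed doc comment, callers of this helper supplied a perVectorCallback that was invoked once per vectorized group of register elements with the computed shared-memory address. The fragment below is a minimal, hypothetical usage sketch, not code from this commit; it assumes the Triton lowering utilities referenced above (SharedMemoryObject, TargetInfoBase) and the standard MLIR LLVM-dialect builders.

// Illustrative sketch only (not part of this commit): load a tensor's elements
// from shared memory into individual register Values via the callback API.
static LogicalResult loadSharedToRegisters(
    RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
    Type elemLlvmTy, const SharedMemoryObject &smemObj, Location loc,
    RewriterBase &rewriter, const TargetInfoBase &targetInfo,
    SmallVectorImpl<Value> &loadedVals) {
  bool ok = emitTransferBetweenRegistersAndShared(
      registerTy, sharedTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj,
      loc, rewriter, targetInfo,
      [&](VectorType vecTy, Value shmemAddr) {
        // One vector load per group of contiguous register elements; the
        // callback decides whether the transfer is a load or a store.
        Value vec = rewriter.create<LLVM::LoadOp>(loc, vecTy, shmemAddr);
        for (int i = 0; i < vecTy.getNumElements(); ++i) {
          Value idx = rewriter.create<LLVM::ConstantOp>(
              loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(i));
          loadedVals.push_back(
              rewriter.create<LLVM::ExtractElementOp>(loc, vec, idx));
        }
      });
  // The helper returns false for cases it cannot handle, e.g. cross-CTA access.
  return success(ok);
}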

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 104 deletions
@@ -706,110 +706,6 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
                  maybeMaxVecElems, localLoadOp);
 }
 
-bool emitTransferBetweenRegistersAndShared(
-    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    Value laneId, Value warpId,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  MLIRContext *ctx = rewriter.getContext();
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-  StringAttr kBlock = str_attr("block");
-  StringAttr kRegister = str_attr("register");
-  StringAttr kLane = str_attr("lane");
-  StringAttr kWarp = str_attr("warp");
-  StringAttr kOffset = str_attr("offset");
-
-  auto shape = sharedTy.getShape();
-  auto paddedEnc =
-      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedTy.getEncoding());
-  LinearLayout regToSharedLayout = LinearLayout::empty();
-  if (paddedEnc) {
-    const auto &sharedLL = paddedEnc.getLinearComponent();
-    regToSharedLayout = regLayout.invertAndCompose(sharedLL);
-  } else {
-    auto sharedLL = triton::gpu::toLinearLayout(sharedTy);
-    regToSharedLayout = regLayout.invertAndCompose(sharedLL);
-  }
-
-  // TODO(jlebar): We don't currently support loading from shared memory in a
-  // different CTA. We'd need to emit `mapa.shared::cluster` instructions.
-  if (regToSharedLayout.hasInDim(kBlock) &&
-      regToSharedLayout.hasOutDim(kBlock) &&
-      !regToSharedLayout.isTrivialOver({kBlock})) {
-    return false;
-  }
-
-  // Determine how many consecutive registers map to consecutive shmem elements
-  // in out-dimension offsetN. This is our load instruction's vector width.
-  //
-  // It's OK if the vector width we choose here is wider than the hardware
-  // supports; LLVM will legalize it.
-  int vecElems =
-      std::min({regToSharedLayout.getNumConsecutiveInOut(),
-                maxVecElems.value_or(std::numeric_limits<int>::max())});
-  if (paddedEnc) {
-    vecElems = std::min(vecElems, int(paddedEnc.getMinInterval()));
-  }
-
-  auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1;
-  Value blockId =
-      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
-
-  int numElems = regToSharedLayout.getInDimSize(kRegister);
-  auto vecTy = vec_ty(elemLlvmTy, vecElems);
-  SmallVector<uint32_t> regIds;
-  for (int i = 0; i < numElems / vecElems; i++) {
-    regIds.push_back(i * vecElems);
-  }
-
-  auto smemBase = smemObj.getBase();
-
-  auto indicesVec = applyLinearLayoutVec(loc, rewriter, regToSharedLayout,
-                                         {{kRegister, b.i32_val(0)},
-                                          {kLane, laneId},
-                                          {kWarp, warpId},
-                                          {kBlock, blockId}},
-                                         regIds);
-
-  // Compute affine offset given by memdesc_subslice
-  auto offset = smemObj.getShmemOffset(loc, rewriter, sharedTy);
-  SmallVector<Value> vecAddrVec;
-  for (auto &indices : indicesVec) {
-    Value smemOffset = indices[0].second;
-    smemOffset = b.xor_(smemOffset, offset);
-    if (paddedEnc) {
-      // Apply the offset needed for padding.
-      auto bitwidth = elemLlvmTy.getIntOrFloatBitWidth();
-      Value padOffset = emitPadding(loc, rewriter, paddedEnc, bitwidth,
-                                    smemOffset, /*offsetInBytes=*/false);
-      smemOffset = b.add(smemOffset, padOffset);
-    }
-    auto vecAddr = b.gep(smemBase.getType(), elemLlvmTy, smemBase, smemOffset,
-                         LLVM::GEPNoWrapFlags::inbounds);
-    vecAddrVec.push_back(vecAddr);
-  }
-
-  for (Value &vecAddr : vecAddrVec) {
-    perVectorCallback(vecTy, vecAddr);
-  }
-  return true;
-}
-
-bool emitTransferBetweenRegistersAndShared(
-    RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
-    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  auto regLayout = triton::gpu::toLinearLayout(registerTy);
-  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-  return emitTransferBetweenRegistersAndShared(
-      regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
-      target, laneId, warpId, perVectorCallback);
-}
-
 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter) {
   assert(bool(llvmStruct) && "can not unpack null values");

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 26 additions & 0 deletions
@@ -73,6 +73,31 @@ bool isConvertTrivial(ConvertLayoutOp op) {
 // Canonicalizer
 //===----------------------------------------------------------------------===//
 
+// tmem_store(cvt) -> tmem_store
+struct CanonicalizeConvertFromTMEMStore
+    : public mlir::OpRewritePattern<nvidia_gpu::TMEMStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(nvidia_gpu::TMEMStoreOp op,
+                  PatternRewriter &rewriter) const override {
+    auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
+    if (!convert)
+      return failure();
+
+    // bail for incompatible layouts
+    auto cvtSrcType = convert.getSrc().getType();
+    if (!nvidia_gpu::isDistributedLayoutTMemCompatible(
+            op.getOperation(), cvtSrcType, op.getDst().getType())) {
+      return failure();
+    }
+
+    rewriter.modifyOpInPlace(
+        op, [&]() { op.getSrcMutable().assign(convert.getSrc()); });
+    return mlir::success();
+  }
+};
+
 // reshape(cvt) -> reshape
 struct CanonicalizeConvertFromReshape
     : public mlir::OpRewritePattern<triton::ReshapeOp> {
@@ -373,6 +398,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<CanonicalizeConvertFromAlloc>(context);
   patterns.add<CanonicalizeConvertFromLocalStore>(context);
   patterns.add<CanonicalizeConvertFromSplit>(context);
+  patterns.add<CanonicalizeConvertFromTMEMStore>(context);
 }
 
 LogicalResult Fp4ToFpOp::verify() {
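
Editorial note: the new pattern rewrites a tmem_store whose source comes through a convert_layout to store the convert's operand directly, whenever isDistributedLayoutTMemCompatible accepts the original layout, so the layout conversion can fold away. As a hypothetical sketch of where it takes effect, canonicalization collects these patterns via ConvertLayoutOp::getCanonicalizationPatterns and runs MLIR's greedy rewrite driver; the driver call below is illustrative only, not part of this commit.

// Illustrative sketch only: collect the ConvertLayoutOp canonicalization
// patterns (which now include CanonicalizeConvertFromTMEMStore) and apply them
// greedily, so tmem_store(convert_layout(x)) becomes tmem_store(x) when the
// source layout is TMEM-compatible.
static LogicalResult runConvertLayoutCanonicalization(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  triton::gpu::ConvertLayoutOp::getCanonicalizationPatterns(patterns, ctx);
  return applyPatternsAndFoldGreedily(module.getOperation(),
                                      std::move(patterns));
}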

python/src/gluon_ir.cc

Lines changed: 7 additions & 0 deletions
@@ -763,6 +763,13 @@ void init_gluon_ir(py::module &&m) {
             self.create<ttag::BufferLoadToLocalOp>(
                 dest, ptr, offsets, mask, other, stride, cacheModifier);
           })
+      .def("create_make_tensor_descriptor",
+           [](TritonOpBuilder &self, Type resultTy, Value &base,
+              std::vector<Value> &shape, std::vector<Value> &strides,
+              tt::PaddingOption paddingOption) -> Value {
+             return self.create<tt::MakeTensorDescOp>(resultTy, base, shape,
+                                                      strides, paddingOption);
+           })
       .def("create_async_tdm_copy_global_to_local",
            [](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
               Value result) {

python/test/gluon/test_frontend.py

Lines changed: 2 additions & 2 deletions
@@ -2763,12 +2763,12 @@ def test_amd_tdm(target):
     %c128_i32 = arith.constant 128 : i32
     %c128_i64 = arith.constant 128 : i64
     %c1_i64 = arith.constant 1 : i64
-    %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : <f16>, <tensor<16x64xf16>>
+    %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : <f16>, <tensor<16x64xf16, #shared>>
     %1 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
     %c0_i32 = arith.constant 0 : i32
     %c2_i32 = arith.constant 2 : i32
     %true = arith.constant true
-    %2 = amdgpu.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %1, %true : !tt.tensordesc<tensor<16x64xf16>> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
+    %2 = amdgpu.async_tdm_copy_global_to_local %0[%c0_i32, %c2_i32] into %1, %true : !tt.tensordesc<tensor<16x64xf16, #shared>> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
     %3 = amdgpu.async_tdm_wait {num = 0 : i32}
     %4 = ttg.local_load %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<16x64xf16, #blocked>
     tt.return

python/test/unit/test_debuginfo.py

Lines changed: 54 additions & 29 deletions
@@ -1,40 +1,65 @@
 import os
-import subprocess
 
-all_names = ["offsets", "pid", "block_start", "mask", "x", "y", "output"]
+import pytest
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def add_kernel(
+    x_ptr,
+    y_ptr,
+    output_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    tl.store(output_ptr + offsets, output, mask=mask)
 
 
 def checkDbgInfo(llir, hasDbgInfo):
     assert hasDbgInfo == ('dbg_value' in llir)
-    for name in all_names:
+    for name in ["offsets", "pid", "block_start", "mask", "x", "y", "output"]:
         assert hasDbgInfo == ('!DILocalVariable(name: \"' + name + '\"' in llir)
 
 
-def test_triton_debuginfo_on():
-    lineInfoKey = "TRITON_DISABLE_LINE_INFO"
-    diLocalVarKey = "LLVM_EXTRACT_DI_LOCAL_VARIABLES"
+@pytest.mark.parametrize("lineInfoKey, diLocalVarKey, hasDbgInfo", [
+    (None, None, False),
+    # expect dbginfo based on parent proccess' TRITON_DISABLE_LINE_INFO
+    (None, "1", "infer"),
+    ("0", "1", True),
+    ("1", "1", False),
+    ("0", "0", False),
+    ("1", "0", False),
+])
+def test_triton_debuginfo_on(lineInfoKey, diLocalVarKey, hasDbgInfo, device, monkeypatch):
+    lineInfoKeyName = "TRITON_DISABLE_LINE_INFO"
+    diLocalVarKeyName = "LLVM_EXTRACT_DI_LOCAL_VARIABLES"
+    if lineInfoKey is not None:
+        monkeypatch.setenv(lineInfoKeyName, lineInfoKey)
+    if diLocalVarKey is not None:
+        monkeypatch.setenv(diLocalVarKeyName, diLocalVarKey)
 
     isEnvSet = lambda env, str: env.get(str, None) is not None
-    hasOrigLineInfo = (not isEnvSet(os.environ, lineInfoKey)
-                       or os.environ[lineInfoKey].lower() not in ["on", "true", "1"])
-    envs = [
-        # expect no dbginfo if unset
-        {lineInfoKey: None, diLocalVarKey: None, "hasDbgInfo": False},
-        # expect dbginfo based on parent proccess' TRITON_DISABLE_LINE_INFO
-        {lineInfoKey: None, diLocalVarKey: "1", "hasDbgInfo": hasOrigLineInfo},
-        {lineInfoKey: "0", diLocalVarKey: "1", "hasDbgInfo": True},
-        {lineInfoKey: "1", diLocalVarKey: "1", "hasDbgInfo": False},
-        {lineInfoKey: "0", diLocalVarKey: "0", "hasDbgInfo": False},
-        {lineInfoKey: "1", diLocalVarKey: "0", "hasDbgInfo": False},
-    ]
-
-    _run_test = lambda test_env: subprocess.run([
-        "python3", os.path.dirname(os.path.realpath(__file__)) + "/test_debuginfo_helper.py"
-    ], env=test_env, capture_output=True, text=True)
-    for env in envs:
-        test_env = os.environ.copy()
-        test_env["TRITON_ALWAYS_COMPILE"] = "1"
-        for entry in env:
-            if not isEnvSet(env, entry): continue
-            test_env[entry] = str(env[entry])
-        checkDbgInfo(str(_run_test(test_env).stdout), hasDbgInfo=env["hasDbgInfo"])
+    if hasDbgInfo == "infer":
+        hasDbgInfo = (not isEnvSet(os.environ, lineInfoKeyName)
+                      or os.environ[lineInfoKeyName].lower() not in ["on", "true", "1"])
+
+    size = 98432
+    torch.manual_seed(0)
+    x = torch.rand(size, device=device)
+    y = torch.rand(size, device=device)
+    output = torch.empty_like(x)
+    n_elements = output.numel()
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+    add_kernel.device_caches.clear()
+    h = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+    checkDbgInfo(h.asm['llir'], hasDbgInfo)
python/test/unit/test_debuginfo_helper.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

python/triton/experimental/gluon/language/_core.py

Lines changed: 0 additions & 4 deletions
@@ -509,10 +509,6 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti
     """
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
    worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
-    if not isinstance(default_args, tuple):
-        default_args = (default_args, )
-    if not isinstance(worker_args, tuple):
-        worker_args = (worker_args, )
     return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
                                      worker_num_regs, _generator)
 
python/triton/experimental/gluon/language/_semantic.py

Lines changed: 4 additions & 0 deletions
@@ -420,6 +420,10 @@ def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
     def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
                         worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
         num_partitions = len(worker_partitions)
+        _check(isinstance(default_args, (tuple, ttgl.tuple)),
+               lambda: f"default_args must be a tuple of arguments, but got {type(default_args)}")
+        _check(isinstance(worker_args, (tuple, ttgl.tuple)),
+               lambda: f"worker_args must be a tuple of arguments, but got {type(worker_args)}")
         assert num_partitions == len(
             worker_num_warps
         ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts"
