Skip to content

Commit 9a0e0cb

Browse files
pchen7e2 authored and meta-codesync[bot] committed
[1/N][TLX-2cta] Introduce TTNG_MapToRemoteBufferOp (#637)
Summary: To be able to essentially call NV's "mapa" on an SMEM buffer (or a barrier living there), we need to open up this API to front end. This will make it possible to explicitly arrive a remote barrier, and in the future, read/write DSMEM. Note if the CTAId is the executing CTA, original src address will be returned. Marking this Op as `MemDescViewTrait` will automatically handle alias analysis like MemDescIndex ops etc. ``` % make test-lit ninja -C /data/users/pchen7e4/triton/build/cmake.linux-x86_64-cpython-3.11 check-triton-lit-tests ninja: Entering directory `/data/users/pchen7e4/triton/build/cmake.linux-x86_64-cpython-3.11' [0/1] Running the triton regression tests Testing Time: 7.81s Total Discovered Tests: 208 Passed : 207 (99.52%) Expectedly Failed: 1 (0.48%) % third_party/tlx/run_all.sh Need to build triton in this script? {y|n}n Run all LITs? {y|n}n Run core Triton python unit tests? {y|n}n Run all TLX unit tests? {y|n}y Running TLX Unit Tests ... ====================================================================================== 31 passed, 76 skipped in 19.55s ====================================================================================== Run TLX tutorial kernels (correctness|performance|no)? {c|p|n} c Verifying correctness of TLX tutorial kernels (all passing) ``` Pull Request resolved: #637 Reviewed By: htyu Differential Revision: D86244418 Pulled By: pchen7e2 fbshipit-source-id: f62f28aefe5630d81e27fa9395e2f973db72b015
1 parent b8dfa52 commit 9a0e0cb

File tree

9 files changed

+167
-5
lines changed

9 files changed

+167
-5
lines changed

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,25 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
7070
let assemblyFormat = "attr-dict";
7171
}
7272

73+
def TTNG_MapToRemoteBufferOp : TTNG_Op<"map_to_remote_buffer", [Pure, MemDescViewTrait]> {
  let summary = "Map shared memory buffer to the corresponding buffer in the target CTA";

  let description = [{
    Given a shared memory buffer mem desc `src`, return a mem desc referring to the corresponding buffer in the
    specified target CTA.

    `$ctaRank` refers to the unique CTA id in a cluster across all dims. e.g. For a 2x4 CTA cluster, a valid CTA rank
    will be 0~7.
  }];

  let arguments = (ins TTG_MemDescType:$src, I32:$ctaRank);

  let results = (outs TTG_MemDescType:$result);

  let assemblyFormat = [{$src`,` $ctaRank attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];

  // Verifier (in Ops.cpp) checks that src/result agree except for memory space.
  let hasVerifier = 1;
}
91+
7392
//
7493
// WarpGroupDot Op
7594
//

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,29 @@ namespace mlir {
3838
namespace triton {
3939
namespace nvidia_gpu {
4040

41+
LogicalResult MapToRemoteBufferOp::verify() {
  // The source and result mem descs must be identical except for the memory
  // space: the result is the same buffer viewed from another CTA in the
  // cluster.
  MemDescType srcTy = getSrc().getType();
  MemDescType dstTy = getResult().getType();

  bool sameExceptSpace = srcTy.getShape() == dstTy.getShape() &&
                         srcTy.getElementType() == dstTy.getElementType() &&
                         srcTy.getEncoding() == dstTy.getEncoding() &&
                         srcTy.getMutableMemory() == dstTy.getMutableMemory() &&
                         srcTy.getAllocShape() == dstTy.getAllocShape();
  if (!sameExceptSpace) {
    return emitOpError() << "Local MemDesc not matching Remote MemDesc: "
                         << srcTy << " vs " << dstTy;
  }

  // The source must live in plain (local) shared memory...
  if (!isa<SharedMemorySpaceAttr>(srcTy.getMemorySpace())) {
    return emitOpError() << "Invalid memory space for local MemDesc: "
                         << srcTy;
  }
  // ...and the result must be tagged as cluster-shared memory.
  if (!isa<SharedClusterMemorySpaceAttr>(dstTy.getMemorySpace())) {
    return emitOpError() << "Invalid memory space for remote MemDesc: "
                         << dstTy;
  }
  return success();
}
63+
4164
// -- WarpGroupDotOp --
4265
LogicalResult WarpGroupDotOp::inferReturnTypes(
4366
MLIRContext *context, std::optional<Location> location, ValueRange operands,

test/Conversion/tritonnvidiagpu_to_llvm.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
271271
tt.return
272272
}
273273
}
274+
275+
// -----
276+
277+
// Lowering test: map_to_remote_buffer on an mbarrier-style 1xi64 smem buffer
// must lower to nvvm.mapa, mapping addrspace(3) to addrspace(7).
// CHECK-LABEL: map_smem_to_remote
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @map_smem_to_remote(%arg: !ttg.memdesc<1xi64, #shared, #smem, mutable>) {
    %c1_i32 = arith.constant 1 : i32
    // CHECK: nvvm.mapa %{{.*}} : !llvm.ptr<3> -> !llvm.ptr<7>
    %0 = ttng.map_to_remote_buffer %arg, %c1_i32: !ttg.memdesc<1xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #ttng.shared_cluster_memory, mutable>
    tt.return
  }
}

test/TritonNvidiaGPU/invalid.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
// RUN: triton-opt --split-input-file %s --verify-diagnostics
22

3+
// Negative test: the verifier must reject a result mem desc whose memory
// space is plain #smem instead of #ttng.shared_cluster_memory.
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @map_smem_to_remote(%arg: !ttg.memdesc<1xi64, #shared, #smem, mutable>) {
    %c1_i32 = arith.constant 1 : i32
    // expected-error @+1 {{Invalid memory space for remote MemDesc}}
    %0 = ttng.map_to_remote_buffer %arg, %c1_i32: !ttg.memdesc<1xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable>
    tt.return
  }
}
13+
14+
// -----
15+
316
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
417
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
518
tt.func public @alloc_tensor_memory() {

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "PatternTritonGPUOpToLLVM.h"
2626
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2727
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
28+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
2829
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
2930

3031
using namespace mlir;
@@ -63,12 +64,55 @@ struct ClusterWaitOpConversion
6364
return success();
6465
}
6566
};
67+
68+
// lower MapToRemoteBufferOp
69+
struct MapToRemoteBufferOpConversion
70+
: public ConvertOpToLLVMPattern<triton::nvidia_gpu::MapToRemoteBufferOp> {
71+
using ConvertOpToLLVMPattern<
72+
triton::nvidia_gpu::MapToRemoteBufferOp>::ConvertOpToLLVMPattern;
73+
74+
LogicalResult
75+
matchAndRewrite(triton::nvidia_gpu::MapToRemoteBufferOp op, OpAdaptor adaptor,
76+
ConversionPatternRewriter &rewriter) const override {
77+
auto loc = op.getLoc();
78+
auto srcSmemObj = LLVM::getSharedMemoryObjectFromStruct(
79+
loc, adaptor.getSrc(),
80+
typeConverter->convertType(op.getSrc().getType().getElementType()),
81+
rewriter);
82+
auto srcSmemPtr = srcSmemObj.getBase();
83+
84+
auto ptrTy = cast<LLVM::LLVMPointerType>(srcSmemPtr.getType());
85+
assert(ptrTy.getAddressSpace() == 3 &&
86+
"Invalid src llvm addr space for MapToRemoteBufferOp");
87+
88+
// The result pointer is referring to a memory buffer living in a CTA
89+
// cluster, so it has a different memory space. NVVM::MapaOp verifies its
90+
// src and result ptr type, so we need to construct the result ptr type
91+
// from typeConverter output here
92+
LLVM::LLVMStructType convertedRetTy =
93+
cast<LLVM::LLVMStructType>(typeConverter->convertType(op.getType()));
94+
Type convertedPtrTy = convertedRetTy.getBody()[0];
95+
96+
// map an SMEM ptr in mem space 3 to a ptr in mem space 7
97+
auto remotePtr = rewriter.create<NVVM::MapaOp>(
98+
loc, convertedPtrTy, srcSmemPtr, adaptor.getCtaRank());
99+
100+
// everything stays the same except base ptr comparing to srcSmemObj
101+
auto dstSmemObj = SharedMemoryObject(
102+
remotePtr, srcSmemObj.getBaseElemType(), srcSmemObj.getOffsets());
103+
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
104+
rewriter.replaceOp(op, retVal);
105+
return success();
106+
}
107+
};
108+
66109
} // namespace
67110

68111
void mlir::triton::NVIDIA::populateClusterOpsToLLVMPatterns(
69112
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
70113
PatternBenefit benefit) {
71114
patterns.add<ClusterArriveOpConversion>(typeConverter, benefit);
72115
patterns.add<ClusterWaitOpConversion>(typeConverter, benefit);
116+
patterns.add<MapToRemoteBufferOpConversion>(typeConverter, benefit);
73117
return;
74118
}

third_party/tlx/dialect/triton_tlx.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,22 @@ void init_triton_tlx_ir(py::module &&m) {
554554
threadId = self.create<arith::IndexCastOp>(
555555
self.getBuilder().getI32Type(), threadId);
556556
return threadId;
557+
})
558+
.def("create_map_to_remote_buffer",
559+
[](TritonOpBuilder &self, Value &src,
560+
Value &clusterCTARank) -> Value {
561+
auto bufferType = cast<ttg::MemDescType>(src.getType());
562+
assert(
563+
isa<ttg::SharedMemorySpaceAttr>(bufferType.getMemorySpace()) &&
564+
"Input of MapToRemoteBuffer has to be local SMEM");
565+
auto newBufferType = ttg::MemDescType::get(
566+
bufferType.getShape(), bufferType.getElementType(),
567+
bufferType.getEncoding(),
568+
ttng::SharedClusterMemorySpaceAttr::get(self.getContext()),
569+
bufferType.getMutableMemory(), bufferType.getAllocShape());
570+
Value remoteBuf = self.create<ttng::MapToRemoteBufferOp>(
571+
newBufferType, src, clusterCTARank);
572+
return remoteBuf;
557573
});
558574
}
559575

third_party/tlx/language/tlx/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
CLCPipelineContext,
1616
async_token,
1717
)
18-
from .mem_ops import (local_alloc, local_view, local_slice, subslice, async_load, async_load_commit_group,
18+
from .mem_ops import (local_alloc, local_view, remote_view, local_slice, subslice, async_load, async_load_commit_group,
1919
async_load_wait_group, local_load, local_store, local_trans, local_reinterpret,
2020
async_descriptor_load, async_descriptor_store, async_descriptor_store_wait, fence_async_shared)
2121
from .barrier import (
@@ -70,6 +70,7 @@
7070
# mem_ops
7171
"local_alloc",
7272
"local_view",
73+
"remote_view",
7374
"local_slice",
7475
"subslice",
7576
"async_load",

third_party/tlx/language/tlx/mem_ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from . import types as tlx
44
from .utility import cuda_parse_arch
55
from .mma_ops import require_nv_mma_shared_layout
6+
from .types import storage_kind
67
from typing import Optional, Tuple, overload
78

89

@@ -183,6 +184,35 @@ def _buffered_tensor_getitem(self, buffer_idx):
183184
return local_view(self, buffer_idx, _semantic=self.type.semantic)
184185

185186

187+
@tl.builtin
def remote_view(
    local_allocated_buffer: tlx.mbarrier,
    remote_cta_rank: int | tl.constexpr | tl.tensor,
    _semantic=None,
) -> tlx.mbarrier:
    """
    Returns a remote view of the buffer. This returns a remote buf handle living in a CTA in the same CTA cluster with the
    executing CTA.

    :arg local_allocated_buffer: the local buffer handle we start with
    :arg remote_cta_rank: unique ID of the remote CTA within the CTA cluster. This ID is across all dims, so e.g. for
        a cluster of shape [2, 4] a valid unique ID could be 0~7, including the executing CTA itself
    :returns: a remote view of the buffer, located at the same relative location, but just in a possibly different CTA
    """
    assert isinstance(local_allocated_buffer, tlx.mbarrier), "remote_view only supports barrier for now"
    assert local_allocated_buffer.type.storage == storage_kind.smem, "remote_view requires local smem as input"
    if isinstance(remote_cta_rank, (tl.constexpr, int)):
        # Materialize the compile-time rank as an i32 IR value.
        remote_cta_rank_handle = _semantic._convert_elem_to_ir_value(tl._unwrap_if_constexpr(remote_cta_rank),
                                                                     require_i64=False)
    else:
        assert isinstance(
            remote_cta_rank, tl.tensor
        ), f"`remote_cta_rank` is in type {type(remote_cta_rank)} (must be either `tl.tensor` or `tl.constexpr`)"
        remote_cta_rank_handle = remote_cta_rank.handle
    remote_buf_handle = _semantic.builder.create_map_to_remote_buffer(local_allocated_buffer.handle,
                                                                      remote_cta_rank_handle)
    # BUGFIX: the 4th positional parameter of `tlx.mbarrier` is `semantics`,
    # not `storage`. Passing `storage_kind.smemCluster` positionally stored the
    # storage enum in the semantics slot and left `storage` at its smem
    # default, so the returned view reported smem storage. Forward the local
    # buffer's semantic and pass `storage` by keyword instead.
    return tlx.mbarrier(remote_buf_handle, 0, local_allocated_buffer.type.layout,
                        local_allocated_buffer.type.semantic, storage=storage_kind.smemCluster)
214+
215+
186216
tlx.buffered_tensor.__getitem__ = _buffered_tensor_getitem
187217
tlx.mbarrier.__getitem__ = _buffered_tensor_getitem
188218
tlx.clc_response.__getitem__ = _buffered_tensor_getitem

third_party/tlx/language/tlx/types.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import enum
55
from abc import abstractmethod
66
from triton._C.libtriton import ir
7+
78
from triton.language.semantic import TritonSemantic
89

910

@@ -287,9 +288,10 @@ class mbarrier(tl.base_value):
287288
"""
288289

289290
def __init__(self, handle, num: int, layout: Optional[swizzled_shared_layout_encoding],
290-
semantics: TritonSemantic = None):
291+
semantics: TritonSemantic = None, storage: storage_kind = storage_kind.smem):
292+
assert storage == storage_kind.smem or storage == storage_kind.smemCluster, "mbarrier requires storage to be smem or smemCluster"
291293
self.handle = handle
292-
self.type = mbarrier_type(num, layout, semantics)
294+
self.type = mbarrier_type(num, layout, semantics, storage)
293295
self.num = num
294296

295297
def _flatten_ir(self, handles) -> None:
@@ -305,8 +307,8 @@ def _unflatten_ir(self, handles, cursor):
305307

306308
class mbarrier_type(buffered_tensor_type):
307309

308-
def __init__(self, num: int, layout: Optional[swizzled_shared_layout_encoding], semantic: TritonSemantic):
309-
super().__init__(tl.int64, [1], num, storage_kind.smem, layout, semantic)
310+
def __init__(self, num: int, layout: Optional[swizzled_shared_layout_encoding], semantic: TritonSemantic, storage):
311+
super().__init__(tl.int64, [1], num, storage, layout, semantic)
310312

311313
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[mbarrier, int]:
312314
value = mbarrier(handles[cursor], self.num, self.layout, self.semantic)

0 commit comments

Comments
 (0)