Commit 9a1750c

[TRTLLM-9493][noop] Refactor fusedMoeCommKernels to enable code sharing (#9922)
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent e0a4b72 commit 9a1750c

File tree

5 files changed: +471 -290 lines changed

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
/*
 * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cuda_runtime.h>
#include <stdint.h>

#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"

namespace tensorrt_llm
{
namespace kernels
{

// ============================================================================
// Address Conversion Utilities
// ============================================================================

static __device__ __forceinline__ uint32_t __as_ptr_smem(void const* __ptr)
{
    // Consider adding debug asserts here.
    return static_cast<uint32_t>(__cvta_generic_to_shared(__ptr));
}

static __device__ __forceinline__ uint64_t __as_ptr_gmem(void const* __ptr)
{
    // Consider adding debug asserts here.
    return static_cast<uint64_t>(__cvta_generic_to_global(__ptr));
}

// ============================================================================
// Memory Fence Operations
// ============================================================================

__device__ __forceinline__ void fence_release_sys()
{
    asm volatile("fence.release.sys;" : : : "memory");
}

// ============================================================================
// Memory Barrier Operations (mbarrier)
// ============================================================================

__device__ __forceinline__ void mbarrier_init(uint64_t* addr, uint32_t const& count)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm("mbarrier.init.shared.b64 [%0], %1;" : : "r"(__as_ptr_smem(addr)), "r"(count) : "memory");
#endif
}

__device__ __forceinline__ void mbarrier_expect_tx(uint64_t* addr, const uint32_t txCount)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm("mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;"
        :
        : "r"(__as_ptr_smem(addr)), "r"(txCount)
        : "memory");
#endif
}

__device__ __forceinline__ uint64_t mbarrier_arrive(uint64_t* addr)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    uint64_t state;
    asm("mbarrier.arrive.shared.b64 %0, [%1];" : "=l"(state) : "r"(__as_ptr_smem(addr)) : "memory");
    return state;
#else
    return 0;
#endif
}

__device__ __forceinline__ uint64_t mbarrier_arrive_expect_tx(uint64_t* addr, const uint32_t txCount)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    uint64_t state;
    asm("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 %0, [%1], %2;"
        : "=l"(state)
        : "r"(__as_ptr_smem(addr)), "r"(txCount)
        : "memory");
    return state;
#else
    return 0;
#endif
}

__device__ __forceinline__ bool mbarrier_try_wait_parity(uint64_t* addr, uint32_t const& phaseParity)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    uint32_t waitComplete;
    asm("{\n\t .reg .pred P_OUT; \n\t"
        "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2;\n\t"
        "selp.b32 %0, 1, 0, P_OUT; \n"
        "}"
        : "=r"(waitComplete)
        : "r"(__as_ptr_smem(addr)), "r"(phaseParity)
        : "memory");
    return static_cast<bool>(waitComplete);
#else
    return false;
#endif
}

// ============================================================================
// Async Copy Operations (cp.async for SM80+)
// ============================================================================

template <int COPY_SIZE = 4>
__device__ __forceinline__ void ldgsts(int* dstShm, int const* srcMem, bool predGuard)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
        "}\n" ::"r"((int) predGuard),
        "r"(__as_ptr_smem(dstShm)), "l"(__as_ptr_gmem(srcMem)), "n"(COPY_SIZE));
#endif
}

__device__ __forceinline__ void cp_async_commit_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.commit_group;" : : :);
#endif
}

template <int N = 0>
__device__ __forceinline__ void cp_async_wait_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.wait_group %0;" : : "n"(N) : "memory");
#endif
}

// ============================================================================
// Bulk Async Copy Operations (cp.async.bulk for SM90+)
// ============================================================================

__device__ __forceinline__ void cp_async_bulk_g2s(void* dstMem, void const* srcMem, int copySize, uint64_t* smemBar)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
        :
        : "r"(__as_ptr_smem(dstMem)), "l"(__as_ptr_gmem(srcMem)), "r"(copySize), "r"(__as_ptr_smem(smemBar))
        : "memory");
#endif
}

__device__ __forceinline__ void cp_async_bulk_s2g(void* dstMem, void const* srcMem, int copySize)
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;"
        :
        : "l"(__as_ptr_gmem(dstMem)), "r"(__as_ptr_smem(srcMem)), "r"(copySize)
        : "memory");
#endif
}

__device__ __forceinline__ void cp_async_bulk_commit_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.commit_group;" : : :);
#endif
}

template <int N = 0>
__device__ __forceinline__ void cp_async_bulk_wait_group()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.wait_group %0;" : : "n"(N) : "memory");
#endif
}

template <int N = 0>
__device__ __forceinline__ void cp_async_bulk_wait_group_read()
{
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(N) : "memory");
#endif
}

// ============================================================================
// Shared Memory Barrier Helpers
// ============================================================================

__device__ __forceinline__ void initSmemBar(uint64_t* smemBar, int laneId)
{
    if (laneId == 0)
    {
        mbarrier_init(smemBar, WARP_SIZE);
    }
    __syncwarp();
}

__device__ __forceinline__ void smemBarWait(uint64_t* smemBar, uint32_t* phaseParity)
{
    while (!mbarrier_try_wait_parity(smemBar, *phaseParity))
    {
    }
    *phaseParity = 1 - *phaseParity;
}

} // namespace kernels
} // namespace tensorrt_llm
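
For context, here is a minimal usage sketch (illustrative only, not part of this commit) of how the SM90 helpers above are meant to compose: one lane posts the expected transaction byte count and launches the bulk copy, the remaining lanes arrive on the mbarrier, and the whole warp spins on the phase parity until the tile has landed in shared memory. The kernel name warpCopyTileSm90, the tile size, and the include path are hypothetical; the real call sites are in the refactored fusedMoeCommKernels.

// Illustrative sketch only -- not part of this commit. Assumes a single-warp block,
// sm_90 compilation, TILE_BYTES a multiple of 16, and 16-byte-aligned buffers
// (alignment/size requirements of cp.async.bulk). The include path is hypothetical.
#include "tensorrt_llm/kernels/fusedMoeAsyncCopyUtils.h" // hypothetical path to the header above

namespace tensorrt_llm
{
namespace kernels
{

template <int TILE_BYTES>
__global__ void warpCopyTileSm90(int4 const* gmemSrc, int4* gmemDst)
{
    __shared__ alignas(16) int4 smemTile[TILE_BYTES / sizeof(int4)];
    __shared__ uint64_t smemBar;

    int laneId = threadIdx.x % WARP_SIZE;
    uint32_t phaseParity = 0;

    // One-time setup: lane 0 initializes the mbarrier with an arrival count of WARP_SIZE.
    initSmemBar(&smemBar, laneId);

    // Global -> shared: lane 0 posts the expected bytes and issues the bulk copy;
    // the other lanes only arrive. The phase completes once all 32 lanes have arrived
    // and TILE_BYTES bytes have been delivered to shared memory.
    if (laneId == 0)
    {
        mbarrier_arrive_expect_tx(&smemBar, TILE_BYTES);
        cp_async_bulk_g2s(smemTile, gmemSrc, TILE_BYTES, &smemBar);
    }
    else
    {
        mbarrier_arrive(&smemBar);
    }
    smemBarWait(&smemBar, &phaseParity);

    // ... per-lane work on smemTile would go here ...
    __syncwarp();

    // Shared -> global: lane 0 issues the bulk store, commits the bulk group, and waits
    // until the group has finished reading smemTile so the buffer could be reused.
    if (laneId == 0)
    {
        cp_async_bulk_s2g(gmemDst, smemTile, TILE_BYTES);
        cp_async_bulk_commit_group();
        cp_async_bulk_wait_group_read<0>();
    }
    __syncwarp();
}

} // namespace kernels
} // namespace tensorrt_llm

The SM80 path composes analogously: ldgsts issues predicated cp.async copies, and cp_async_commit_group / cp_async_wait_group<N> bound how many copy groups a thread keeps in flight before it touches the data in shared memory.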
