Skip to content

Commit 910c070

Browse files
bo-nv and syuoni authored
[None][fix] fix accuracy issue (cherry-pick #11157 and #9530) (#11222)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com> Signed-off-by: Bo Deng <deemod@nvidia.com> Co-authored-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent d248aef commit 910c070

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ __global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues
120120
auto warp = cg::tiled_partition<WARP_SIZE>(block);
121121

122122
BaseType minScore = BaseType{-INFINITY};
123+
124+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
125+
cudaGridDependencySynchronize();
126+
#endif
127+
123128
for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
124129
{
125130
auto scoreOffset = tokenId * numExperts;
@@ -168,6 +173,10 @@ __global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues
168173
}
169174
}
170175
} // end for tokenId
176+
177+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
178+
cudaTriggerProgrammaticLaunchCompletion();
179+
#endif
171180
}
172181

173182
int nextPowerOfTwo(int num)

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939

4040
# Use TinyGEMM when the number of tokens is not larger than this threshold
4141
MIN_LATENCY_TINYGEMM_NUM_TOKENS = 128
42+
# Enable TinyGEMM optimization (disabled by default, set ENABLE_TINYGEMM=1 to enable)
43+
ENABLE_TINYGEMM = os.environ.get('ENABLE_TINYGEMM', '0') == '1'
4244

4345

4446
class AttentionBlock(Attention):
@@ -226,7 +228,7 @@ def _create_ideal_expert_load_balanced_logits(
226228
dtype=pretrained_config.torch_dtype)
227229

228230
def compute_gate_output(self, x: torch.Tensor) -> torch.Tensor:
229-
if get_sm_version() in [
231+
if ENABLE_TINYGEMM and get_sm_version() in [
230232
90, 100, 103
231233
] and x.shape[0] <= MIN_LATENCY_TINYGEMM_NUM_TOKENS:
232234
weight = self.gate.weight

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1530,7 +1530,8 @@ def previous_seq_slots_device():
15301530
num_draft_tokens = len(draft_tokens)
15311531
total_num_tokens = len(position_ids)
15321532
assert total_num_tokens <= self.max_num_tokens, (
1533-
"total_num_tokens should be less than or equal to max_num_tokens")
1533+
f"total_num_tokens ({total_num_tokens}) should be less than or equal to max_num_tokens ({self.max_num_tokens})"
1534+
)
15341535
# if exist requests that do not have previous batch, copy input_ids and draft_tokens
15351536
if num_tokens > 0:
15361537
input_ids = torch.tensor(input_ids,

0 commit comments

Comments (0)