
[Bug]: DeepSeek V3.2 max_tokens larger than 8196 crashes #9932

@Shang-Pin

Description


System Info

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

```shell
trtllm-serve $BASE_MODEL \
        --backend pytorch \
        --tp_size 8 \
        --ep_size 8 \
        --host 0.0.0.0 \
        --port 8001 \
        --max_batch_size 64 \
        --max_num_tokens 32768 \
        --max_seq_len 163840 \
        --extra_llm_api_options $EXTRA_CONFIG_FILENAME > $LOGDIR/out.log &
```

extra.yaml

```yaml
cuda_graph_config:
  max_batch_size: 512
  enable_padding: true
enable_chunked_prefill: false
enable_attention_dp: false
disable_overlap_scheduler: false
print_iter_log: true
kv_cache_config:
  enable_block_reuse: false
  free_gpu_memory_fraction: 0.9
  tokens_per_block: 64
moe_config:
  backend: CUTLASS
```

```python
import time
import openai

# Assumed module-level client setup; port 8001 matches the trtllm-serve command above.
openai.base_url = "http://localhost:8001/v1"
openai.api_key = "dummy"  # trtllm-serve does not validate the key

def test():
    t = time.time()
    text = "Hello " * 150000
    openai.completions.create(model="", prompt="\n\n<|User|>%s    <|Assistant|>      **</think>**" % text, max_tokens=100)
    print(time.time() - t)
```

Expected behavior

The server should not crash.

Actual behavior

```
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 2061, in _forward_step
    outputs = forward(scheduled_requests, self.resource_manager,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 2049, in forward
    return self.model_engine.forward(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/utils.py", line 108, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2818, in forward
    outputs = self._forward_step(inputs, gather_ids,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2883, in _forward_step
    outputs = self.model_forward(
              ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2870, in model_forward
    return self.model.forward(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/models/modeling_deepseekv3.py", line 1694, in forward
    return super().forward(attn_metadata=attn_metadata,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/models/modeling_speculative.py", line 596, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/models/modeling_deepseekv3.py", line 1593, in forward
    hidden_states, residual = decoder_layer(
                              ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/models/modeling_deepseekv3.py", line 1226, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/modules/attention.py", line 2120, in forward
    self.forward_impl_with_dsa(position_ids,
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/modules/attention.py", line 1285, in forward_impl_with_dsa
    self.forward_context_dsa(
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/modules/attention.py", line 1370, in forward_context_dsa
    return self.forward_absorption_context(q,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/modules/attention.py", line 1891, in forward_absorption_context
    torch.ops.trtllm.bmm_out(q_nope_t,
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1255, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py", line 644, in adinplaceorview_impl
    return self._opoverload.redispatch(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 848, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py", line 343, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py", line 376, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py", line 27, in bmm_out
    torch.bmm(a, b, out=out)
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmStridedBatchedEx(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF, (int)ldb, strideb, (void*)&fbeta, c, std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF, (int)ldc, stridec, (int)num_batches, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
```

Additional notes

@jhaotingc suspects that, since TRT-LLM's default chunk size (max_num_tokens) is 8192, setting it to a larger value may cause some GEMMs to run with shapes that are not yet supported (and not yet discovered).
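Since the traceback shows that trtllm.bmm_out is a thin wrapper around torch.bmm, one way to test this hypothesis is to sweep the same call in isolation across token counts on either side of 8192. The shapes below are assumptions (DeepSeek-V3 MLA absorption dimensions: 128 heads split across tp_size 8, qk_nope_head_dim 128, kv_lora_rank 512), not values captured from the crash:

```python
import torch

# Hypothetical per-rank shapes for the MLA absorption bmm:
# 128 heads / tp_size 8 = 16, qk_nope_head_dim = 128, kv_lora_rank = 512.
# Only num_tokens is suspected to matter here.
num_heads, head_dim, kv_lora_rank = 16, 128, 512

for num_tokens in (8192, 8193, 16384, 32768):
    q = torch.randn(num_heads, num_tokens, head_dim, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(num_heads, head_dim, kv_lora_rank, device="cuda", dtype=torch.bfloat16)
    out = torch.empty(num_heads, num_tokens, kv_lora_rank, device="cuda", dtype=torch.bfloat16)
    # Same call the traceback ends in: trtllm.bmm_out does torch.bmm(a, b, out=out).
    torch.bmm(q, w, out=out)
    torch.cuda.synchronize()  # surface any asynchronous CUDA error at this line
    print(f"num_tokens={num_tokens}: ok")
```

If this sweep passes, the failure is more likely introduced elsewhere in the absorption path (for example by a non-contiguous q_nope_t layout) than by the raw GEMM dimensions alone.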

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.

Metadata

Labels

  • Customized kernels<NV>: Specialized/modified CUDA kernels in TRTLLM for LLM ops, beyond standard TRT. Dev & perf.
  • bug: Something isn't working
  • stale
  • waiting for feedback
