Skip to content

Commit 0b6c9e3

Browse files
authored
bugfix: check if acc_logprob is defined before H2D copy. (#293)
1 parent d8e8bb5 commit 0b6c9e3

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,7 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(
164164
// beam search kernel
165165
BeamSearchOutput beam_search_output;
166166
if (concated_sampling_params.use_beam_search &&
167-
inputs.acc_logprob.numel() > 0) {
167+
inputs.acc_logprob.defined() && inputs.acc_logprob.numel() > 0) {
168168
beam_search_output = beam_searcher_->forward(inputs.acc_logprob,
169169
sample_output.top_tokens,
170170
sample_output.top_logprobs);

xllm/core/runtime/worker_impl.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -454,8 +454,10 @@ void WorkerImpl::prepare_work_before_execute(
454454
}
455455
processed_inputs.concated_sampling_params =
456456
inputs.concated_sampling_params.to(device_, dtype_);
457-
processed_inputs.acc_logprob =
458-
inputs.acc_logprob.to(torch::kFloat32).to(device_);
457+
if (inputs.acc_logprob.defined()) {
458+
processed_inputs.acc_logprob =
459+
inputs.acc_logprob.to(torch::kFloat32).to(device_);
460+
}
459461
auto ret = prepare_stream_->synchronize();
460462
}
461463

0 commit comments

Comments (0)