File tree Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -164,7 +164,7 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(
164164 // beam search kernel
165165 BeamSearchOutput beam_search_output;
166166 if (concated_sampling_params.use_beam_search &&
167- inputs.acc_logprob .numel () > 0 ) {
167+ inputs.acc_logprob .defined () && inputs. acc_logprob . numel () > 0 ) {
168168 beam_search_output = beam_searcher_->forward (inputs.acc_logprob ,
169169 sample_output.top_tokens ,
170170 sample_output.top_logprobs );
Original file line number Diff line number Diff line change @@ -454,8 +454,10 @@ void WorkerImpl::prepare_work_before_execute(
454454 }
455455 processed_inputs.concated_sampling_params =
456456 inputs.concated_sampling_params .to (device_, dtype_);
457- processed_inputs.acc_logprob =
458- inputs.acc_logprob .to (torch::kFloat32 ).to (device_);
457+ if (inputs.acc_logprob .defined ()) {
458+ processed_inputs.acc_logprob =
459+ inputs.acc_logprob .to (torch::kFloat32 ).to (device_);
460+ }
459461 auto ret = prepare_stream_->synchronize ();
460462}
461463
You can’t perform that action at this time.
0 commit comments