diff --git a/skyrl-tx/tests/tinker/test_engine.py b/skyrl-tx/tests/tinker/test_engine.py
index 2633a6088..4d91016d9 100644
--- a/skyrl-tx/tests/tinker/test_engine.py
+++ b/skyrl-tx/tests/tinker/test_engine.py
@@ -475,7 +475,7 @@ def test_sample_prompt_logprobs_with_microbatching():
                 sampling_params=sampling_params,
                 num_samples=1,
                 checkpoint_id="",
-                prompt_logprobs=True,
+                prompt_logprobs=i in {2, 4},
             ),
         )
         for i, tokens in enumerate(prompts)
@@ -484,15 +484,15 @@ def test_sample_prompt_logprobs_with_microbatching():
     results = engine.process_sample_batch(reqs)
 
     # Verify that each request got its correct prompt_logprobs
-    for i, tokens in enumerate(prompts):
-        request_id = f"req_{i}"
+    for i, (request_id, (_, sample_input)) in enumerate(reqs.items()):
         result = results[request_id]
 
-        # Verify prompt_logprobs are returned
-        assert result.prompt_logprobs is not None, f"Request {request_id}: prompt_logprobs should not be None"
-
-        # Verify correct length
-        expected_length = len(tokens) - 1
-        assert (
-            len(result.prompt_logprobs) == expected_length
-        ), f"Request {request_id}: expected {expected_length} prompt_logprobs, got {len(result.prompt_logprobs)}"
+        if not sample_input.prompt_logprobs:
+            # Verify prompt_logprobs is not returned
+            assert result.prompt_logprobs is None, f"Request {request_id}: prompt_logprobs should be None"
+        else:
+            # Verify correct length
+            expected_length = len(prompts[i]) - 1
+            assert (
+                len(result.prompt_logprobs) == expected_length
+            ), f"Request {request_id}: expected {expected_length} prompt_logprobs, got {len(result.prompt_logprobs)}"
diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py
index b3379c12d..4c40a8242 100644
--- a/skyrl-tx/tx/tinker/engine.py
+++ b/skyrl-tx/tx/tinker/engine.py
@@ -617,11 +617,9 @@ def process_sample_batch(
         if not valid_requests:
             return results
 
-        # Computes prompt_logprobs for the whole batch if any request asked for them
-        needs_prompt_logprobs = any(request_data.prompt_logprobs for (_, request_data) in valid_requests.values())
-
         all_prompts = []
         all_sampling_params = []
+        all_request_logprobs = []
         all_adapter_indices = []
         request_batch_slices = []
 
@@ -636,9 +634,10 @@ def process_sample_batch(
             prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens]
             all_prompts.append(prompt_tokens)
             all_sampling_params.append(request_data.sampling_params)
+            all_request_logprobs.append(request_data.prompt_logprobs)
             all_adapter_indices.append(adapter_indices_batch[i])
 
-            request_batch_slices.append((request_id, model_id, request_start, len(all_prompts), request_data))
+            request_batch_slices.append((request_id, model_id, request_start, len(all_prompts)))
 
         total_batch_size = len(all_prompts)
         max_batch_size = (
@@ -646,7 +645,7 @@ def process_sample_batch(
         )
         # Collect generated sequences and prompt logprobs across batches
         all_sequences: list[types.GeneratedSequence] = []
-        all_prompt_logprobs: list[list[float]] = []
+        all_prompt_logprobs: list[list[float] | None] = []
 
         with jax.set_mesh(self.mesh):
             model = nnx.merge(self.graphdef, self.lora_params, self.non_lora_params)
@@ -657,7 +656,8 @@ def process_sample_batch(
                 sampling_params = pad(
                     all_sampling_params[batch_start:batch_end], max_batch_size, fill=all_sampling_params[batch_start]
                 )
-
+                # Compute prompt_logprobs if any request in this micro-batch needs them
+                needs_prompt_logprobs = any(all_request_logprobs[batch_start:batch_end])
                 # Pad sequences to same length within the batch to minimize memory usage.
                 # Also bin it so the JIT has to compile fewer kernels.
                 max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0))
@@ -682,15 +682,15 @@ def process_sample_batch(
                         result.logprobs[:batch_size],
                     )
                 )
-                if needs_prompt_logprobs and result.prompt_logprobs:
-                    all_prompt_logprobs.extend(result.prompt_logprobs[:batch_size])
+                all_prompt_logprobs.extend(
+                    result.prompt_logprobs[i] if request_logprobs else None
+                    for i, request_logprobs in enumerate(all_request_logprobs[batch_start:batch_end])
+                )
 
-        for request_id, _, start_idx, end_idx, request_data in request_batch_slices:
+        for request_id, _, start_idx, end_idx in request_batch_slices:
             sequences = [all_sequences[i] for i in range(start_idx, end_idx)]
             # Each of `num_samples` samples in a request share the same prompt; use the first's prompt logprobs
-            prompt_logprobs = (
-                all_prompt_logprobs[start_idx] if request_data.prompt_logprobs and all_prompt_logprobs else None
-            )
+            prompt_logprobs = all_prompt_logprobs[start_idx]
             results[request_id] = types.SampleOutput(sequences=sequences, prompt_logprobs=prompt_logprobs)
 
         return results
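
For reference, a minimal standalone sketch of the per-request selection logic this patch introduces: within each micro-batch, prompt logprobs are computed whenever any request in that micro-batch set the flag, and each request then keeps its own row only if its own flag was set (others receive None). The helper name select_prompt_logprobs and the inline data are hypothetical illustrations, not part of the engine's API.

def select_prompt_logprobs(
    batch_prompt_logprobs: list[list[float]] | None,
    request_flags: list[bool],
) -> list[list[float] | None]:
    """Map per-row prompt logprobs to per-request results.

    batch_prompt_logprobs: one row per request in the micro-batch, or None
        when no request in the micro-batch asked for prompt logprobs.
    request_flags: the prompt_logprobs flag of each request, in batch order.
    """
    if batch_prompt_logprobs is None:
        # Nothing was computed for this micro-batch; every request gets None.
        return [None] * len(request_flags)
    # Keep a request's row only if that request set the flag.
    return [
        batch_prompt_logprobs[i] if flag else None
        for i, flag in enumerate(request_flags)
    ]


# Example: requests 0 and 2 asked for prompt logprobs, request 1 did not.
rows = [[-0.1, -0.2], [-0.3, -0.4], [-0.5, -0.6]]
print(select_prompt_logprobs(rows, [True, False, True]))
# [[-0.1, -0.2], None, [-0.5, -0.6]]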