22 changes: 11 additions & 11 deletions skyrl-tx/tests/tinker/test_engine.py
@@ -475,7 +475,7 @@ def test_sample_prompt_logprobs_with_microbatching():
                 sampling_params=sampling_params,
                 num_samples=1,
                 checkpoint_id="",
-                prompt_logprobs=True,
+                prompt_logprobs=i in {2, 4},
             ),
         )
         for i, tokens in enumerate(prompts)
@@ -484,15 +484,15 @@ def test_sample_prompt_logprobs_with_microbatching():
     results = engine.process_sample_batch(reqs)
 
     # Verify that each request got its correct prompt_logprobs
-    for i, tokens in enumerate(prompts):
-        request_id = f"req_{i}"
+    for i, (request_id, (_, sample_input)) in enumerate(reqs.items()):
         result = results[request_id]
 
-        # Verify prompt_logprobs are returned
-        assert result.prompt_logprobs is not None, f"Request {request_id}: prompt_logprobs should not be None"
-
-        # Verify correct length
-        expected_length = len(tokens) - 1
-        assert (
-            len(result.prompt_logprobs) == expected_length
-        ), f"Request {request_id}: expected {expected_length} prompt_logprobs, got {len(result.prompt_logprobs)}"
+        if not sample_input.prompt_logprobs:
+            # Verify prompt_logprobs is not returned
+            assert result.prompt_logprobs is None, f"Request {request_id}: prompt_logprobs should be None"
+        else:
+            # Verify correct length
+            expected_length = len(prompts[i]) - 1
+            assert (
+                len(result.prompt_logprobs) == expected_length
+            ), f"Request {request_id}: expected {expected_length} prompt_logprobs, got {len(result.prompt_logprobs)}"
24 changes: 12 additions & 12 deletions skyrl-tx/tx/tinker/engine.py
@@ -617,11 +617,9 @@ def process_sample_batch(
         if not valid_requests:
             return results
 
-        # Computes prompt_logprobs for the whole batch if any request asked for them
-        needs_prompt_logprobs = any(request_data.prompt_logprobs for (_, request_data) in valid_requests.values())
-
         all_prompts = []
         all_sampling_params = []
+        all_request_logprobs = []
         all_adapter_indices = []
         request_batch_slices = []
 
@@ -636,17 +634,18 @@
             prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens]
             all_prompts.append(prompt_tokens)
             all_sampling_params.append(request_data.sampling_params)
+            all_request_logprobs.append(request_data.prompt_logprobs)
             all_adapter_indices.append(adapter_indices_batch[i])
 
-            request_batch_slices.append((request_id, model_id, request_start, len(all_prompts), request_data))
+            request_batch_slices.append((request_id, model_id, request_start, len(all_prompts)))
 
         total_batch_size = len(all_prompts)
         max_batch_size = (
             self.config.sample_max_num_sequences if self.config.sample_max_num_sequences > 0 else total_batch_size
         )
         # Collect generated sequences and prompt logprobs across batches
         all_sequences: list[types.GeneratedSequence] = []
-        all_prompt_logprobs: list[list[float]] = []
+        all_prompt_logprobs: list[list[float] | None] = []
 
         with jax.set_mesh(self.mesh):
             model = nnx.merge(self.graphdef, self.lora_params, self.non_lora_params)
@@ -657,7 +656,8 @@ def process_sample_batch(
                 sampling_params = pad(
                     all_sampling_params[batch_start:batch_end], max_batch_size, fill=all_sampling_params[batch_start]
                 )
-
+                # Compute prompt_logprobs if any request in this micro-batch needs them
+                needs_prompt_logprobs = any(all_request_logprobs[batch_start:batch_end])
                 # Pad sequences to same length within the batch to minimize memory usage.
                 # Also bin it so the JIT has to compile fewer kernels.
                 max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0))
@@ -682,15 +682,15 @@
                         result.logprobs[:batch_size],
                     )
                 )
-                if needs_prompt_logprobs and result.prompt_logprobs:
-                    all_prompt_logprobs.extend(result.prompt_logprobs[:batch_size])
+                all_prompt_logprobs.extend(
+                    result.prompt_logprobs[i] if request_logprobs else None
+                    for i, request_logprobs in enumerate(all_request_logprobs[batch_start:batch_end])
+                )
Contributor comment on lines +685 to +688 (severity: high):

The current implementation for extending all_prompt_logprobs could lead to a TypeError if model.generate returns None for prompt_logprobs even when they were requested for the micro-batch (needs_prompt_logprobs is True). The previous code included a check for result.prompt_logprobs being truthy, which was removed, making the access to result.prompt_logprobs[i] potentially unsafe.

To make the code more robust, it's better to reintroduce a check for result.prompt_logprobs being non-None before attempting to index it. The suggested change below does this in a compact way.

Suggested change:
-                all_prompt_logprobs.extend(
-                    result.prompt_logprobs[i] if request_logprobs else None
-                    for i, request_logprobs in enumerate(all_request_logprobs[batch_start:batch_end])
-                )
+                all_prompt_logprobs.extend(
+                    result.prompt_logprobs[i] if request_logprobs and result.prompt_logprobs else None
+                    for i, request_logprobs in enumerate(all_request_logprobs[batch_start:batch_end])
+                )
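
For illustration, a minimal self-contained sketch of the failure mode the comment describes and how the suggested guard avoids it. FakeResult and gather below are hypothetical stand-ins, not the engine's real model.generate return type or code:

# Hypothetical stand-in, only to illustrate the reviewer's point; not the real
# model.generate return type.
from dataclasses import dataclass


@dataclass
class FakeResult:
    prompt_logprobs: list[list[float]] | None


def gather(result: FakeResult, request_logprobs: list[bool]) -> list[list[float] | None]:
    # Guarded version from the suggestion: if the model returned no prompt
    # logprobs at all, fall back to None instead of indexing into None.
    return [
        result.prompt_logprobs[i] if wants and result.prompt_logprobs else None
        for i, wants in enumerate(request_logprobs)
    ]


# Logprobs present: only the requesting slots get values.
print(gather(FakeResult(prompt_logprobs=[[-0.1], [-0.2]]), [True, False]))  # [[-0.1], None]
# Logprobs missing despite a request: the guard yields None instead of raising TypeError.
print(gather(FakeResult(prompt_logprobs=None), [True, False]))  # [None, None]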


-        for request_id, _, start_idx, end_idx, request_data in request_batch_slices:
+        for request_id, _, start_idx, end_idx in request_batch_slices:
             sequences = [all_sequences[i] for i in range(start_idx, end_idx)]
             # Each of `num_samples` samples in a request share the same prompt; use the first's prompt logprobs
-            prompt_logprobs = (
-                all_prompt_logprobs[start_idx] if request_data.prompt_logprobs and all_prompt_logprobs else None
-            )
+            prompt_logprobs = all_prompt_logprobs[start_idx]
             results[request_id] = types.SampleOutput(sequences=sequences, prompt_logprobs=prompt_logprobs)
 
         return results
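
The per-micro-batch gating this diff introduces can be summarized in isolation. The sketch below is an illustrative toy (fake_generate and run_batches are hypothetical names, not the engine's API): prompt logprobs are computed for a micro-batch only when at least one of its requests asked for them, and each request's slot is then filled with its own logprobs or None.

def fake_generate(batch_prompts, compute_prompt_logprobs):
    # Stand-in for the model call: one dummy logprob per prompt token after the first.
    if not compute_prompt_logprobs:
        return None
    return [[-0.5] * (len(p) - 1) for p in batch_prompts]


def run_batches(prompts, wants_logprobs, max_batch_size):
    all_prompt_logprobs = []
    for start in range(0, len(prompts), max_batch_size):
        end = min(start + max_batch_size, len(prompts))
        # Only compute prompt logprobs if some request in this micro-batch asked for them.
        needs = any(wants_logprobs[start:end])
        batch_logprobs = fake_generate(prompts[start:end], needs)
        all_prompt_logprobs.extend(
            batch_logprobs[i] if wants and batch_logprobs else None
            for i, wants in enumerate(wants_logprobs[start:end])
        )
    return all_prompt_logprobs


# Requests 2 and 4 ask for prompt logprobs (mirroring the test's `i in {2, 4}`).
prompts = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10], [11, 12]]
print(run_batches(prompts, [i in {2, 4} for i in range(5)], max_batch_size=2))
# -> [None, None, [-0.5, -0.5, -0.5], None, [-0.5]]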