Skip to content

Commit ce5139e

Browse files
authored
Revert "Use trl.generation.VLLMGeneration instead of the separate vLLM logic"
This reverts commit 05eac2c.
1 parent 05eac2c commit ce5139e

File tree

3 files changed

+471
-413
lines changed

3 files changed

+471
-413
lines changed

tests/experimental/test_gold_trainer.py

Lines changed: 51 additions & 268 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from datasets import load_dataset
2020
from transformers import AutoTokenizer
2121

22-
from trl.experimental.gold import gold_trainer as gold_trainer_module
2322
from trl.experimental.gold.gold_trainer import GOLDTrainer, ULDLoss, build_teacher_inputs_from_texts
2423
from trl.experimental.utils import DataCollatorForChatML
2524

@@ -290,11 +289,58 @@ def pad_labels(labels, target_length):
290289
return labels + [-100] * (target_length - len(labels))
291290

292291

293-
def test_process_completions_to_buffer_left_pads_prompt_ids():
292+
def test_process_completions_to_buffer_left_pads_prompt_retokenization():
293+
class DummyBatch:
294+
def __init__(self, input_ids):
295+
self.input_ids = input_ids
296+
297+
def to(self, device):
298+
self.input_ids = self.input_ids.to(device)
299+
return self
300+
294301
class RecordingTokenizer:
295302
pad_token_id = 0
296303
pad_token = "<pad>"
297304

305+
def __init__(self):
306+
self.padding_side = "right"
307+
self.calls = []
308+
self._prompt_ids = {
309+
"short": [11],
310+
"longer": [21, 22],
311+
}
312+
313+
def __call__(
314+
self,
315+
texts,
316+
return_tensors,
317+
padding,
318+
truncation,
319+
max_length,
320+
add_special_tokens,
321+
padding_side=None,
322+
):
323+
assert return_tensors == "pt"
324+
assert padding == "longest"
325+
assert not truncation
326+
assert max_length is None
327+
assert not add_special_tokens
328+
self.calls.append(padding_side)
329+
330+
side = padding_side or self.padding_side
331+
encoded = [torch.tensor(self._prompt_ids[text], dtype=torch.long) for text in texts]
332+
max_len = max(len(ids) for ids in encoded)
333+
334+
padded = []
335+
for ids in encoded:
336+
pad_width = max_len - len(ids)
337+
if pad_width:
338+
pad = torch.full((pad_width,), self.pad_token_id, dtype=torch.long)
339+
ids = torch.cat([pad, ids]) if side == "left" else torch.cat([ids, pad])
340+
padded.append(ids)
341+
342+
return DummyBatch(torch.stack(padded))
343+
298344
def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False):
299345
del skip_special_tokens, clean_up_tokenization_spaces
300346
return [" ".join(str(token) for token in sequence) for sequence in sequences]
@@ -312,282 +358,19 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati
312358
on_policy_indices=[0],
313359
local_slice_indices=[0, 0],
314360
completion_ids=[[31], [41]],
315-
prompts_text_with_special=["short", "longer"],
316-
prompt_ids_list=[[11], [21, 22]],
317361
prompts_text=["short", "longer"],
362+
prompts_text_with_special=["short", "longer"],
318363
max_completion_length=1,
319364
)
320365

321366
buffered_inputs = trainer._buffered_inputs[0]
367+
assert trainer.processing_class.calls == ["left"]
368+
assert trainer.processing_class.padding_side == "right"
322369
assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[0, 11, 31], [21, 22, 41]], dtype=torch.long))
323370
assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.long))
324371
assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, 31], [-100, -100, 41]]))
325372

326373

327-
def test_generate_on_policy_for_slices_uses_prompt_attention_mask_for_vllm_prompts():
328-
class RecordingVLLMGeneration:
329-
def __init__(self):
330-
self.prompts = None
331-
self.sync_calls = 0
332-
333-
def sync_weights(self):
334-
self.sync_calls += 1
335-
336-
def generate(self, prompts, images, num_generations):
337-
self.prompts = prompts
338-
assert images is None
339-
assert num_generations == 1
340-
return None, [[42]], None, None
341-
342-
class RecordingTokenizer:
343-
pad_token_id = 9
344-
pad_token = "<eos>"
345-
346-
def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False):
347-
del clean_up_tokenization_spaces
348-
decoded = []
349-
token_map = {5: "A", 6: "B", 9: "<eos>"}
350-
for sequence in sequences:
351-
tokens = []
352-
for token in sequence:
353-
token = int(token)
354-
if skip_special_tokens and token == 9:
355-
continue
356-
tokens.append(token_map[token])
357-
decoded.append(" ".join(tokens))
358-
return decoded
359-
360-
captured = {}
361-
362-
def capture_process_completions(
363-
slices,
364-
on_policy_indices,
365-
local_slice_indices,
366-
completion_ids,
367-
prompt_ids_list,
368-
prompts_text_with_special,
369-
prompts_text,
370-
max_completion_length,
371-
):
372-
captured["slices"] = slices
373-
captured["on_policy_indices"] = on_policy_indices
374-
captured["local_slice_indices"] = local_slice_indices
375-
captured["completion_ids"] = completion_ids
376-
captured["prompt_ids_list"] = prompt_ids_list
377-
captured["prompts_text"] = prompts_text
378-
captured["prompts_text_with_special"] = prompts_text_with_special
379-
captured["max_completion_length"] = max_completion_length
380-
381-
trainer = GOLDTrainer.__new__(GOLDTrainer)
382-
trainer.accelerator = SimpleNamespace(is_main_process=True)
383-
trainer.args = SimpleNamespace(report_to=[])
384-
trainer.processing_class = RecordingTokenizer()
385-
trainer.use_vllm = True
386-
trainer.vllm_generation = RecordingVLLMGeneration()
387-
trainer.vllm_sync_frequency = 1
388-
trainer._last_vllm_sync_step = -1
389-
trainer.state = SimpleNamespace(global_step=0)
390-
trainer.num_generations = 1
391-
trainer.generation_config = SimpleNamespace(max_new_tokens=1)
392-
trainer._process_completions_to_buffer = capture_process_completions
393-
394-
slices = [
395-
{
396-
"prompts": torch.tensor([[9, 9, 5, 9, 6]], dtype=torch.long),
397-
"prompt_attention_mask": torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long),
398-
}
399-
]
400-
401-
GOLDTrainer._generate_on_policy_for_slices(trainer, slices, [0])
402-
403-
assert trainer.vllm_generation.prompts == [[5, 9, 6]]
404-
assert trainer.vllm_generation.sync_calls == 1
405-
assert captured["completion_ids"] == [[42]]
406-
assert captured["prompt_ids_list"] == [[5, 9, 6]]
407-
assert captured["prompts_text"] == ["A B"]
408-
assert captured["prompts_text_with_special"] == ["A <eos> B"]
409-
410-
411-
def test_generate_on_policy_for_slices_reconstructs_prompt_with_special_tokens():
412-
class RecordingVLLMGeneration:
413-
def __init__(self):
414-
self.prompts = None
415-
self.sync_calls = 0
416-
417-
def sync_weights(self):
418-
self.sync_calls += 1
419-
420-
def generate(self, prompts, images, num_generations):
421-
self.prompts = prompts
422-
assert images is None
423-
assert num_generations == 1
424-
return None, [[42]], None, None
425-
426-
class RecordingTokenizer:
427-
pad_token_id = 0
428-
pad_token = "<pad>"
429-
430-
def __init__(self):
431-
self.truncation_side = "right"
432-
433-
def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False):
434-
del clean_up_tokenization_spaces
435-
token_map = {0: "<pad>", 5: "A", 6: "B", 13: "<special>", 42: "C"}
436-
decoded = []
437-
for sequence in sequences:
438-
tokens = []
439-
for token in sequence:
440-
token = int(token)
441-
if skip_special_tokens and token == 13:
442-
continue
443-
if token == 0:
444-
continue
445-
tokens.append(token_map[token])
446-
decoded.append(" ".join(tokens))
447-
return decoded
448-
449-
trainer = GOLDTrainer.__new__(GOLDTrainer)
450-
trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True)
451-
trainer.processing_class = RecordingTokenizer()
452-
trainer.args = SimpleNamespace(max_length=None, report_to=[])
453-
trainer.use_vllm = True
454-
trainer.vllm_generation = RecordingVLLMGeneration()
455-
trainer.vllm_sync_frequency = 1
456-
trainer._last_vllm_sync_step = -1
457-
trainer.state = SimpleNamespace(global_step=0)
458-
trainer.num_generations = 1
459-
trainer.generation_config = SimpleNamespace(max_new_tokens=1)
460-
trainer._buffered_inputs = [None]
461-
trainer._buffered_text_logs = [None]
462-
463-
slices = [
464-
{
465-
"slice": "original",
466-
"prompts": torch.tensor([[0, 0, 5, 13, 6]], dtype=torch.long),
467-
"prompt_attention_mask": torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long),
468-
}
469-
]
470-
471-
GOLDTrainer._generate_on_policy_for_slices(trainer, slices, [0])
472-
473-
buffered_inputs = trainer._buffered_inputs[0]
474-
assert trainer.vllm_generation.prompts == [[5, 13, 6]]
475-
assert trainer.vllm_generation.sync_calls == 1
476-
assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[5, 13, 6, 42]], dtype=torch.long))
477-
assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[1, 1, 1, 1]], dtype=torch.long))
478-
assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, -100, 42]], dtype=torch.long))
479-
assert buffered_inputs["original_prompt_text"] == ["A <special> B"]
480-
assert buffered_inputs["original_completion_text"] == ["C"]
481-
assert trainer._buffered_text_logs[0] == (["A B"], ["C"])
482-
483-
484-
def test_gold_trainer_init_defaults_vllm_max_model_length_to_max_length(monkeypatch):
485-
captured = {}
486-
487-
class DummyStudentModel:
488-
def __init__(self):
489-
self.config = SimpleNamespace(_name_or_path="student", vocab_size=17)
490-
self.generation_config = SimpleNamespace(eos_token_id=2)
491-
self.name_or_path = "student"
492-
493-
class DummyTeacherModel:
494-
def __init__(self):
495-
self.resized_to = None
496-
497-
def resize_token_embeddings(self, vocab_size):
498-
self.resized_to = vocab_size
499-
500-
class DummyProcessingClass:
501-
pad_token_id = 0
502-
503-
def fake_sft_init(
504-
self,
505-
model,
506-
args=None,
507-
data_collator=None,
508-
train_dataset=None,
509-
eval_dataset=None,
510-
processing_class=None,
511-
compute_metrics=None,
512-
callbacks=None,
513-
optimizers=None,
514-
preprocess_logits_for_metrics=None,
515-
peft_config=None,
516-
):
517-
del data_collator, train_dataset, eval_dataset, compute_metrics, callbacks, optimizers
518-
del preprocess_logits_for_metrics, peft_config
519-
self.model = model
520-
self.args = args
521-
self.processing_class = processing_class
522-
self.accelerator = SimpleNamespace(
523-
device=torch.device("cpu"),
524-
num_processes=1,
525-
prepare_model=lambda module, evaluation_mode=True: module,
526-
)
527-
self.is_deepspeed_enabled = False
528-
self.is_fsdp_enabled = False
529-
530-
class CapturingVLLMGeneration:
531-
def __init__(self, **kwargs):
532-
captured.update(kwargs)
533-
534-
monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init)
535-
monkeypatch.setattr(gold_trainer_module, "is_vllm_available", lambda: True)
536-
monkeypatch.setattr(gold_trainer_module, "VLLMGeneration", CapturingVLLMGeneration)
537-
538-
args = SimpleNamespace(
539-
model_init_kwargs=None,
540-
max_length=128,
541-
use_liger_kernel=False,
542-
teacher_model_init_kwargs=None,
543-
use_uld_loss=False,
544-
teacher_tokenizer_name_or_path=None,
545-
teacher_model_revision=None,
546-
disable_dropout=False,
547-
lmbda=1.0,
548-
beta=0.5,
549-
temperature=1.0,
550-
top_p=1.0,
551-
seq_kd=False,
552-
num_generations=1,
553-
use_transformers_paged=False,
554-
max_completion_length=16,
555-
top_k=0,
556-
log_completions=False,
557-
log_completions_steps=100,
558-
wandb_log_unique_prompts=True,
559-
num_completions_to_print=None,
560-
per_device_train_batch_size=1,
561-
gradient_accumulation_steps=1,
562-
use_vllm=True,
563-
vllm_mode="colocate",
564-
vllm_structured_outputs_regex=None,
565-
vllm_server_base_url=None,
566-
vllm_server_host="0.0.0.0",
567-
vllm_server_port=8001,
568-
vllm_group_port=51216,
569-
vllm_server_timeout=240.0,
570-
vllm_tensor_parallel_size=1,
571-
vllm_gpu_memory_utilization=0.2,
572-
vllm_max_model_length=None,
573-
vllm_enable_sleep_mode=False,
574-
vllm_model_impl="vllm",
575-
vllm_sync_frequency=1,
576-
)
577-
578-
teacher_model = DummyTeacherModel()
579-
GOLDTrainer(
580-
model=DummyStudentModel(),
581-
teacher_model=teacher_model,
582-
args=args,
583-
data_collator=object(),
584-
processing_class=DummyProcessingClass(),
585-
)
586-
587-
assert teacher_model.resized_to == 17
588-
assert captured["max_model_length"] == 128
589-
590-
591374
def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokenizer):
592375
config = build_config()
593376
loss = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer)

0 commit comments

Comments
 (0)