1919from datasets import load_dataset
2020from transformers import AutoTokenizer
2121
22- from trl .experimental .gold import gold_trainer as gold_trainer_module
2322from trl .experimental .gold .gold_trainer import GOLDTrainer , ULDLoss , build_teacher_inputs_from_texts
2423from trl .experimental .utils import DataCollatorForChatML
2524
@@ -290,11 +289,58 @@ def pad_labels(labels, target_length):
290289 return labels + [- 100 ] * (target_length - len (labels ))
291290
292291
293- def test_process_completions_to_buffer_left_pads_prompt_ids ():
292+ def test_process_completions_to_buffer_left_pads_prompt_retokenization ():
class DummyBatch:
    """Minimal stand-in for a tokenizer output that only carries ``input_ids``."""

    def __init__(self, input_ids):
        # Stored as-is; ``to`` later calls ``.to(device)`` on this object.
        self.input_ids = input_ids

    def to(self, device):
        """Move ``input_ids`` to ``device`` in place and return this batch."""
        moved = self.input_ids.to(device)
        self.input_ids = moved
        return self
300+
class RecordingTokenizer:
    """Fake tokenizer that records the ``padding_side`` passed to each call.

    Prompts are encoded from a fixed lookup table and padded with
    ``pad_token_id`` up to the longest prompt in the batch.
    """

    pad_token_id = 0
    pad_token = "<pad>"

    def __init__(self):
        self.padding_side = "right"
        # One entry per ``__call__``: the ``padding_side`` argument it received.
        self.calls = []
        self._prompt_ids = {"short": [11], "longer": [21, 22]}

    def __call__(
        self,
        texts,
        return_tensors,
        padding,
        truncation,
        max_length,
        add_special_tokens,
        padding_side=None,
    ):
        """Encode ``texts`` into a padded ``DummyBatch``.

        The asserts pin the exact keyword values this fake accepts. An
        explicit ``padding_side`` takes precedence over ``self.padding_side``.
        """
        assert return_tensors == "pt"
        assert padding == "longest"
        assert not truncation
        assert max_length is None
        assert not add_special_tokens
        self.calls.append(padding_side)

        pad_left = (padding_side or self.padding_side) == "left"
        rows = [torch.tensor(self._prompt_ids[text], dtype=torch.long) for text in texts]
        width = max(len(row) for row in rows)

        def _pad(row):
            # Rows already at full width are returned untouched.
            missing = width - len(row)
            if not missing:
                return row
            filler = torch.full((missing,), self.pad_token_id, dtype=torch.long)
            return torch.cat([filler, row] if pad_left else [row, filler])

        return DummyBatch(torch.stack([_pad(row) for row in rows]))

    def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False):
        """Render each sequence as its token values joined by single spaces."""
        del skip_special_tokens, clean_up_tokenization_spaces
        return [" ".join(map(str, sequence)) for sequence in sequences]
@@ -312,282 +358,19 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati
312358 on_policy_indices = [0 ],
313359 local_slice_indices = [0 , 0 ],
314360 completion_ids = [[31 ], [41 ]],
315- prompts_text_with_special = ["short" , "longer" ],
316- prompt_ids_list = [[11 ], [21 , 22 ]],
317361 prompts_text = ["short" , "longer" ],
362+ prompts_text_with_special = ["short" , "longer" ],
318363 max_completion_length = 1 ,
319364 )
320365
321366 buffered_inputs = trainer ._buffered_inputs [0 ]
367+ assert trainer .processing_class .calls == ["left" ]
368+ assert trainer .processing_class .padding_side == "right"
322369 assert torch .equal (buffered_inputs ["input_ids" ], torch .tensor ([[0 , 11 , 31 ], [21 , 22 , 41 ]], dtype = torch .long ))
323370 assert torch .equal (buffered_inputs ["attention_mask" ], torch .tensor ([[0 , 1 , 1 ], [1 , 1 , 1 ]], dtype = torch .long ))
324371 assert torch .equal (buffered_inputs ["labels" ], torch .tensor ([[- 100 , - 100 , 31 ], [- 100 , - 100 , 41 ]]))
325372
326373
327- def test_generate_on_policy_for_slices_uses_prompt_attention_mask_for_vllm_prompts ():
328- class RecordingVLLMGeneration :
329- def __init__ (self ):
330- self .prompts = None
331- self .sync_calls = 0
332-
333- def sync_weights (self ):
334- self .sync_calls += 1
335-
336- def generate (self , prompts , images , num_generations ):
337- self .prompts = prompts
338- assert images is None
339- assert num_generations == 1
340- return None , [[42 ]], None , None
341-
342- class RecordingTokenizer :
343- pad_token_id = 9
344- pad_token = "<eos>"
345-
346- def batch_decode (self , sequences , skip_special_tokens = False , clean_up_tokenization_spaces = False ):
347- del clean_up_tokenization_spaces
348- decoded = []
349- token_map = {5 : "A" , 6 : "B" , 9 : "<eos>" }
350- for sequence in sequences :
351- tokens = []
352- for token in sequence :
353- token = int (token )
354- if skip_special_tokens and token == 9 :
355- continue
356- tokens .append (token_map [token ])
357- decoded .append (" " .join (tokens ))
358- return decoded
359-
360- captured = {}
361-
362- def capture_process_completions (
363- slices ,
364- on_policy_indices ,
365- local_slice_indices ,
366- completion_ids ,
367- prompt_ids_list ,
368- prompts_text_with_special ,
369- prompts_text ,
370- max_completion_length ,
371- ):
372- captured ["slices" ] = slices
373- captured ["on_policy_indices" ] = on_policy_indices
374- captured ["local_slice_indices" ] = local_slice_indices
375- captured ["completion_ids" ] = completion_ids
376- captured ["prompt_ids_list" ] = prompt_ids_list
377- captured ["prompts_text" ] = prompts_text
378- captured ["prompts_text_with_special" ] = prompts_text_with_special
379- captured ["max_completion_length" ] = max_completion_length
380-
381- trainer = GOLDTrainer .__new__ (GOLDTrainer )
382- trainer .accelerator = SimpleNamespace (is_main_process = True )
383- trainer .args = SimpleNamespace (report_to = [])
384- trainer .processing_class = RecordingTokenizer ()
385- trainer .use_vllm = True
386- trainer .vllm_generation = RecordingVLLMGeneration ()
387- trainer .vllm_sync_frequency = 1
388- trainer ._last_vllm_sync_step = - 1
389- trainer .state = SimpleNamespace (global_step = 0 )
390- trainer .num_generations = 1
391- trainer .generation_config = SimpleNamespace (max_new_tokens = 1 )
392- trainer ._process_completions_to_buffer = capture_process_completions
393-
394- slices = [
395- {
396- "prompts" : torch .tensor ([[9 , 9 , 5 , 9 , 6 ]], dtype = torch .long ),
397- "prompt_attention_mask" : torch .tensor ([[0 , 0 , 1 , 1 , 1 ]], dtype = torch .long ),
398- }
399- ]
400-
401- GOLDTrainer ._generate_on_policy_for_slices (trainer , slices , [0 ])
402-
403- assert trainer .vllm_generation .prompts == [[5 , 9 , 6 ]]
404- assert trainer .vllm_generation .sync_calls == 1
405- assert captured ["completion_ids" ] == [[42 ]]
406- assert captured ["prompt_ids_list" ] == [[5 , 9 , 6 ]]
407- assert captured ["prompts_text" ] == ["A B" ]
408- assert captured ["prompts_text_with_special" ] == ["A <eos> B" ]
409-
410-
411- def test_generate_on_policy_for_slices_reconstructs_prompt_with_special_tokens ():
412- class RecordingVLLMGeneration :
413- def __init__ (self ):
414- self .prompts = None
415- self .sync_calls = 0
416-
417- def sync_weights (self ):
418- self .sync_calls += 1
419-
420- def generate (self , prompts , images , num_generations ):
421- self .prompts = prompts
422- assert images is None
423- assert num_generations == 1
424- return None , [[42 ]], None , None
425-
426- class RecordingTokenizer :
427- pad_token_id = 0
428- pad_token = "<pad>"
429-
430- def __init__ (self ):
431- self .truncation_side = "right"
432-
433- def batch_decode (self , sequences , skip_special_tokens = False , clean_up_tokenization_spaces = False ):
434- del clean_up_tokenization_spaces
435- token_map = {0 : "<pad>" , 5 : "A" , 6 : "B" , 13 : "<special>" , 42 : "C" }
436- decoded = []
437- for sequence in sequences :
438- tokens = []
439- for token in sequence :
440- token = int (token )
441- if skip_special_tokens and token == 13 :
442- continue
443- if token == 0 :
444- continue
445- tokens .append (token_map [token ])
446- decoded .append (" " .join (tokens ))
447- return decoded
448-
449- trainer = GOLDTrainer .__new__ (GOLDTrainer )
450- trainer .accelerator = SimpleNamespace (device = torch .device ("cpu" ), is_main_process = True )
451- trainer .processing_class = RecordingTokenizer ()
452- trainer .args = SimpleNamespace (max_length = None , report_to = [])
453- trainer .use_vllm = True
454- trainer .vllm_generation = RecordingVLLMGeneration ()
455- trainer .vllm_sync_frequency = 1
456- trainer ._last_vllm_sync_step = - 1
457- trainer .state = SimpleNamespace (global_step = 0 )
458- trainer .num_generations = 1
459- trainer .generation_config = SimpleNamespace (max_new_tokens = 1 )
460- trainer ._buffered_inputs = [None ]
461- trainer ._buffered_text_logs = [None ]
462-
463- slices = [
464- {
465- "slice" : "original" ,
466- "prompts" : torch .tensor ([[0 , 0 , 5 , 13 , 6 ]], dtype = torch .long ),
467- "prompt_attention_mask" : torch .tensor ([[0 , 0 , 1 , 1 , 1 ]], dtype = torch .long ),
468- }
469- ]
470-
471- GOLDTrainer ._generate_on_policy_for_slices (trainer , slices , [0 ])
472-
473- buffered_inputs = trainer ._buffered_inputs [0 ]
474- assert trainer .vllm_generation .prompts == [[5 , 13 , 6 ]]
475- assert trainer .vllm_generation .sync_calls == 1
476- assert torch .equal (buffered_inputs ["input_ids" ], torch .tensor ([[5 , 13 , 6 , 42 ]], dtype = torch .long ))
477- assert torch .equal (buffered_inputs ["attention_mask" ], torch .tensor ([[1 , 1 , 1 , 1 ]], dtype = torch .long ))
478- assert torch .equal (buffered_inputs ["labels" ], torch .tensor ([[- 100 , - 100 , - 100 , 42 ]], dtype = torch .long ))
479- assert buffered_inputs ["original_prompt_text" ] == ["A <special> B" ]
480- assert buffered_inputs ["original_completion_text" ] == ["C" ]
481- assert trainer ._buffered_text_logs [0 ] == (["A B" ], ["C" ])
482-
483-
484- def test_gold_trainer_init_defaults_vllm_max_model_length_to_max_length (monkeypatch ):
485- captured = {}
486-
487- class DummyStudentModel :
488- def __init__ (self ):
489- self .config = SimpleNamespace (_name_or_path = "student" , vocab_size = 17 )
490- self .generation_config = SimpleNamespace (eos_token_id = 2 )
491- self .name_or_path = "student"
492-
493- class DummyTeacherModel :
494- def __init__ (self ):
495- self .resized_to = None
496-
497- def resize_token_embeddings (self , vocab_size ):
498- self .resized_to = vocab_size
499-
500- class DummyProcessingClass :
501- pad_token_id = 0
502-
503- def fake_sft_init (
504- self ,
505- model ,
506- args = None ,
507- data_collator = None ,
508- train_dataset = None ,
509- eval_dataset = None ,
510- processing_class = None ,
511- compute_metrics = None ,
512- callbacks = None ,
513- optimizers = None ,
514- preprocess_logits_for_metrics = None ,
515- peft_config = None ,
516- ):
517- del data_collator , train_dataset , eval_dataset , compute_metrics , callbacks , optimizers
518- del preprocess_logits_for_metrics , peft_config
519- self .model = model
520- self .args = args
521- self .processing_class = processing_class
522- self .accelerator = SimpleNamespace (
523- device = torch .device ("cpu" ),
524- num_processes = 1 ,
525- prepare_model = lambda module , evaluation_mode = True : module ,
526- )
527- self .is_deepspeed_enabled = False
528- self .is_fsdp_enabled = False
529-
530- class CapturingVLLMGeneration :
531- def __init__ (self , ** kwargs ):
532- captured .update (kwargs )
533-
534- monkeypatch .setattr (gold_trainer_module .SFTTrainer , "__init__" , fake_sft_init )
535- monkeypatch .setattr (gold_trainer_module , "is_vllm_available" , lambda : True )
536- monkeypatch .setattr (gold_trainer_module , "VLLMGeneration" , CapturingVLLMGeneration )
537-
538- args = SimpleNamespace (
539- model_init_kwargs = None ,
540- max_length = 128 ,
541- use_liger_kernel = False ,
542- teacher_model_init_kwargs = None ,
543- use_uld_loss = False ,
544- teacher_tokenizer_name_or_path = None ,
545- teacher_model_revision = None ,
546- disable_dropout = False ,
547- lmbda = 1.0 ,
548- beta = 0.5 ,
549- temperature = 1.0 ,
550- top_p = 1.0 ,
551- seq_kd = False ,
552- num_generations = 1 ,
553- use_transformers_paged = False ,
554- max_completion_length = 16 ,
555- top_k = 0 ,
556- log_completions = False ,
557- log_completions_steps = 100 ,
558- wandb_log_unique_prompts = True ,
559- num_completions_to_print = None ,
560- per_device_train_batch_size = 1 ,
561- gradient_accumulation_steps = 1 ,
562- use_vllm = True ,
563- vllm_mode = "colocate" ,
564- vllm_structured_outputs_regex = None ,
565- vllm_server_base_url = None ,
566- vllm_server_host = "0.0.0.0" ,
567- vllm_server_port = 8001 ,
568- vllm_group_port = 51216 ,
569- vllm_server_timeout = 240.0 ,
570- vllm_tensor_parallel_size = 1 ,
571- vllm_gpu_memory_utilization = 0.2 ,
572- vllm_max_model_length = None ,
573- vllm_enable_sleep_mode = False ,
574- vllm_model_impl = "vllm" ,
575- vllm_sync_frequency = 1 ,
576- )
577-
578- teacher_model = DummyTeacherModel ()
579- GOLDTrainer (
580- model = DummyStudentModel (),
581- teacher_model = teacher_model ,
582- args = args ,
583- data_collator = object (),
584- processing_class = DummyProcessingClass (),
585- )
586-
587- assert teacher_model .resized_to == 17
588- assert captured ["max_model_length" ] == 128
589-
590-
591374def test_alignment_groups_cover_all_tokens (llama_tokenizer , qwen_tokenizer ):
592375 config = build_config ()
593376 loss = ULDLoss (config , student_tokenizer = llama_tokenizer , teacher_tokenizer = qwen_tokenizer )
0 commit comments