@@ -2386,7 +2386,7 @@ def test_batching_continuous_throughput(
         assert len(processing_events) > 0, "No processing occurred"
 
         # Check that processing happened across multiple threads (indicating concurrent processing)
-        thread_ids = set(event["thread_id"] for event in processing_events)
+        thread_ids = {event["thread_id"] for event in processing_events}  # noqa
         assert (
             len(thread_ids) > 1
         ), f"All processing happened in single thread: {thread_ids}"
@@ -2503,6 +2503,259 @@ def test_batching_configuration_validation(
         assert wrapper._min_batch_size == 1
         assert wrapper._max_batch_size == 2
 
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_standardized_generation_parameters(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that standardized generation parameters work across both wrappers."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test with standardized parameters
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            generate_kwargs={
+                "max_new_tokens": 10,  # Standardized name
+                "num_return_sequences": 1,  # Standardized name
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "top_k": 50,
+                "repetition_penalty": 1.1,
+                "do_sample": True,
+                "num_beams": 1,
+                "length_penalty": 1.0,
+                "early_stopping": False,
+                "skip_special_tokens": True,
+                "logprobs": True,
+            },
+        )
+
+        # Test that the wrapper was created successfully
+        assert wrapper is not None
+
+        # Test that the parameters were properly converted
+        if wrapper_class == vLLMWrapper:
+            # Check that vLLM-specific parameters were set
+            assert (
+                wrapper.sampling_params.max_tokens == 10
+            )  # max_new_tokens -> max_tokens
+            assert wrapper.sampling_params.n == 1  # num_return_sequences -> n
+            assert wrapper.sampling_params.temperature == 0.7
+            assert wrapper.sampling_params.top_p == 0.9
+            assert wrapper.sampling_params.top_k == 50
+            assert wrapper.sampling_params.repetition_penalty == 1.1
+            assert wrapper.sampling_params.best_of == 1  # num_beams -> best_of
+            # do_sample=True means we use sampling (temperature > 0), not greedy decoding
+            assert wrapper.sampling_params.temperature > 0
+        else:
+            # Check that Transformers parameters were set
+            assert wrapper.generate_kwargs["max_new_tokens"] == 10
+            assert wrapper.generate_kwargs["num_return_sequences"] == 1
+            assert wrapper.generate_kwargs["temperature"] == 0.7
+            assert wrapper.generate_kwargs["top_p"] == 0.9
+            assert wrapper.generate_kwargs["top_k"] == 50
+            assert wrapper.generate_kwargs["repetition_penalty"] == 1.1
+            assert wrapper.generate_kwargs["do_sample"] is True
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_legacy_parameter_names(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that legacy parameter names are automatically converted to standardized names."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test with legacy parameter names
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            generate_kwargs={
+                "max_tokens": 10,  # Legacy vLLM name
+                "n": 1,  # Legacy vLLM name
+                "temperature": 0.7,
+            },
+        )
+
+        # Test that the wrapper was created successfully
+        assert wrapper is not None
+
+        # Test that the parameters were properly converted
+        if wrapper_class == vLLMWrapper:
+            # Check that legacy names were converted to vLLM format
+            assert (
+                wrapper.sampling_params.max_tokens == 10
+            )  # max_tokens -> max_tokens (no change)
+            assert wrapper.sampling_params.n == 1  # n -> n (no change)
+            assert wrapper.sampling_params.temperature == 0.7
+        else:
+            # Check that legacy names were converted to Transformers format
+            assert (
+                wrapper.generate_kwargs["max_new_tokens"] == 10
+            )  # max_tokens -> max_new_tokens
+            assert (
+                wrapper.generate_kwargs["num_return_sequences"] == 1
+            )  # n -> num_return_sequences
+            assert wrapper.generate_kwargs["temperature"] == 0.7
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_parameter_conflict_resolution(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that parameter conflicts are resolved correctly when both legacy and standardized names are used."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test with conflicting parameters - legacy name should win
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            generate_kwargs={
+                "max_tokens": 20,  # Legacy name
+                "max_new_tokens": 10,  # Standardized name
+                "n": 2,  # Legacy name
+                "num_return_sequences": 1,  # Standardized name
+                "temperature": 0.7,
+            },
+        )
+
+        # Test that the wrapper was created successfully
+        assert wrapper is not None
+
+        # Test that the parameters were properly resolved
+        if wrapper_class == vLLMWrapper:
+            # Legacy names should win
+            assert wrapper.sampling_params.max_tokens == 20  # max_tokens wins
+            assert wrapper.sampling_params.n == 2  # n wins
+            assert wrapper.sampling_params.temperature == 0.7
+        else:
+            # Legacy names should be converted to standardized names
+            assert (
+                wrapper.generate_kwargs["max_new_tokens"] == 20
+            )  # max_tokens -> max_new_tokens
+            assert (
+                wrapper.generate_kwargs["num_return_sequences"] == 2
+            )  # n -> num_return_sequences
+            assert wrapper.generate_kwargs["temperature"] == 0.7
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_parameter_validation(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that parameter validation works correctly."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test invalid temperature
+        with pytest.raises(ValueError, match="Temperature must be non-negative"):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"temperature": -0.1},
+            )
+
+        # Test invalid top_p
+        with pytest.raises(ValueError, match="top_p must be between 0 and 1"):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"top_p": 1.5},
+            )
+
+        # Test invalid top_k
+        with pytest.raises(ValueError, match="top_k must be positive"):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"top_k": 0},
+            )
+
+        # Test invalid repetition_penalty
+        with pytest.raises(ValueError, match="repetition_penalty must be positive"):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"repetition_penalty": 0.5},
+            )
+
+        # Test conflicting do_sample and temperature
+        with pytest.raises(
+            ValueError, match="When do_sample=False.*temperature must be 0"
+        ):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"do_sample": False, "temperature": 0.7},
+            )
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_parameter_conflict_errors(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that parameter conflicts are properly detected."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test conflicting max_tokens and max_new_tokens
+        with pytest.raises(
+            ValueError, match="Cannot specify both 'max_tokens'.*'max_new_tokens'"
+        ):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"max_tokens": 10, "max_new_tokens": 20},
+            )
+
+        # Test conflicting n and num_return_sequences
+        with pytest.raises(
+            ValueError, match="Cannot specify both 'n'.*'num_return_sequences'"
+        ):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"n": 1, "num_return_sequences": 2},
+            )
+
 
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
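
For reference, a minimal sketch of the kind of legacy-to-standardized key mapping these tests exercise. Everything below is hypothetical: the helper name and mapping table are assumptions for illustration, not the PR's actual implementation. The conflict errors mirror the messages asserted in test_parameter_conflict_errors.

    # Hypothetical sketch; not the PR's code. Renames legacy vLLM-style keys
    # to the standardized names and rejects kwargs that pass both forms.
    LEGACY_TO_STANDARD = {
        "max_tokens": "max_new_tokens",
        "n": "num_return_sequences",
    }

    def normalize_generate_kwargs(kwargs: dict) -> dict:
        """Return kwargs with legacy keys renamed to their standardized forms."""
        out = dict(kwargs)
        for legacy, standard in LEGACY_TO_STANDARD.items():
            if legacy in out and standard in out:
                raise ValueError(f"Cannot specify both '{legacy}' and '{standard}'")
            if legacy in out:
                out[standard] = out.pop(legacy)
        return out

    # Example: normalize_generate_kwargs({"max_tokens": 10, "n": 1})
    #   -> {"max_new_tokens": 10, "num_return_sequences": 1}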