
Commit 978424e

[Feature] Common kwargs for generate across vLLM and Transformers (#3107)
1 parent 7752561 commit 978424e
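
This commit gives vLLMWrapper and TransformersWrapperMaxTokens a shared vocabulary of generate_kwargs. As a rough illustration of the mapping the new tests assert (standardized names such as max_new_tokens, num_return_sequences, and num_beams translated to vLLM's max_tokens, n, and best_of), here is a minimal sketch. The helper names below are assumptions for illustration only, not part of this commit, and conflict handling between legacy and standardized names is omitted.

# Illustrative sketch only: the mapping mirrors what the tests assert,
# but the helper names `standardize_generate_kwargs` and
# `to_vllm_sampling_kwargs` are hypothetical.
LEGACY_TO_STANDARD = {"max_tokens": "max_new_tokens", "n": "num_return_sequences"}
STANDARD_TO_VLLM = {
    "max_new_tokens": "max_tokens",  # max_new_tokens -> max_tokens
    "num_return_sequences": "n",  # num_return_sequences -> n
    "num_beams": "best_of",  # num_beams -> best_of
}


def standardize_generate_kwargs(kwargs: dict) -> dict:
    """Map legacy vLLM-style names onto the standardized names."""
    return {LEGACY_TO_STANDARD.get(key, key): value for key, value in kwargs.items()}


def to_vllm_sampling_kwargs(kwargs: dict) -> dict:
    """Map standardized names onto vLLM SamplingParams keyword arguments."""
    standardized = standardize_generate_kwargs(kwargs)
    return {STANDARD_TO_VLLM.get(key, key): value for key, value in standardized.items()}


# Example: the standardized kwargs used in test_standardized_generation_parameters
# would reach vLLM as {"max_tokens": 10, "n": 1, "temperature": 0.7}.
print(to_vllm_sampling_kwargs({"max_new_tokens": 10, "num_return_sequences": 1, "temperature": 0.7}))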

File tree

4 files changed: +672 -26 lines changed


test/llm/test_wrapper.py

Lines changed: 254 additions & 1 deletion
@@ -2386,7 +2386,7 @@ def test_batching_continuous_throughput(
         assert len(processing_events) > 0, "No processing occurred"
 
         # Check that processing happened across multiple threads (indicating concurrent processing)
-        thread_ids = set(event["thread_id"] for event in processing_events)
+        thread_ids = {event["thread_id"] for event in processing_events}  # noqa
         assert (
             len(thread_ids) > 1
         ), f"All processing happened in single thread: {thread_ids}"
@@ -2503,6 +2503,259 @@ def test_batching_configuration_validation(
         assert wrapper._min_batch_size == 1
         assert wrapper._max_batch_size == 2
 
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_standardized_generation_parameters(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that standardized generation parameters work across both wrappers."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test with standardized parameters
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            generate_kwargs={
+                "max_new_tokens": 10,  # Standardized name
+                "num_return_sequences": 1,  # Standardized name
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "top_k": 50,
+                "repetition_penalty": 1.1,
+                "do_sample": True,
+                "num_beams": 1,
+                "length_penalty": 1.0,
+                "early_stopping": False,
+                "skip_special_tokens": True,
+                "logprobs": True,
+            },
+        )
+
+        # Test that the wrapper was created successfully
+        assert wrapper is not None
+
+        # Test that the parameters were properly converted
+        if wrapper_class == vLLMWrapper:
+            # Check that vLLM-specific parameters were set
+            assert (
+                wrapper.sampling_params.max_tokens == 10
+            )  # max_new_tokens -> max_tokens
+            assert wrapper.sampling_params.n == 1  # num_return_sequences -> n
+            assert wrapper.sampling_params.temperature == 0.7
+            assert wrapper.sampling_params.top_p == 0.9
+            assert wrapper.sampling_params.top_k == 50
+            assert wrapper.sampling_params.repetition_penalty == 1.1
+            assert wrapper.sampling_params.best_of == 1  # num_beams -> best_of
+            # do_sample=True means we use sampling (temperature > 0), not greedy decoding
+            assert wrapper.sampling_params.temperature > 0
+        else:
+            # Check that Transformers parameters were set
+            assert wrapper.generate_kwargs["max_new_tokens"] == 10
+            assert wrapper.generate_kwargs["num_return_sequences"] == 1
+            assert wrapper.generate_kwargs["temperature"] == 0.7
+            assert wrapper.generate_kwargs["top_p"] == 0.9
+            assert wrapper.generate_kwargs["top_k"] == 50
+            assert wrapper.generate_kwargs["repetition_penalty"] == 1.1
+            assert wrapper.generate_kwargs["do_sample"] is True
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_legacy_parameter_names(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that legacy parameter names are automatically converted to standardized names."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test with legacy parameter names
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            generate_kwargs={
+                "max_tokens": 10,  # Legacy vLLM name
+                "n": 1,  # Legacy vLLM name
+                "temperature": 0.7,
+            },
+        )
+
+        # Test that the wrapper was created successfully
+        assert wrapper is not None
+
+        # Test that the parameters were properly converted
+        if wrapper_class == vLLMWrapper:
+            # Check that legacy names were converted to vLLM format
+            assert (
+                wrapper.sampling_params.max_tokens == 10
+            )  # max_tokens -> max_tokens (no change)
+            assert wrapper.sampling_params.n == 1  # n -> n (no change)
+            assert wrapper.sampling_params.temperature == 0.7
+        else:
+            # Check that legacy names were converted to Transformers format
+            assert (
+                wrapper.generate_kwargs["max_new_tokens"] == 10
+            )  # max_tokens -> max_new_tokens
+            assert (
+                wrapper.generate_kwargs["num_return_sequences"] == 1
+            )  # n -> num_return_sequences
+            assert wrapper.generate_kwargs["temperature"] == 0.7
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_parameter_conflict_resolution(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that parameter conflicts are resolved correctly when both legacy and standardized names are used."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test with conflicting parameters - legacy name should win
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            generate_kwargs={
+                "max_tokens": 20,  # Legacy name
+                "max_new_tokens": 10,  # Standardized name
+                "n": 2,  # Legacy name
+                "num_return_sequences": 1,  # Standardized name
+                "temperature": 0.7,
+            },
+        )
+
+        # Test that the wrapper was created successfully
+        assert wrapper is not None
+
+        # Test that the parameters were properly resolved
+        if wrapper_class == vLLMWrapper:
+            # Legacy names should win
+            assert wrapper.sampling_params.max_tokens == 20  # max_tokens wins
+            assert wrapper.sampling_params.n == 2  # n wins
+            assert wrapper.sampling_params.temperature == 0.7
+        else:
+            # Legacy names should be converted to standardized names
+            assert (
+                wrapper.generate_kwargs["max_new_tokens"] == 20
+            )  # max_tokens -> max_new_tokens
+            assert (
+                wrapper.generate_kwargs["num_return_sequences"] == 2
+            )  # n -> num_return_sequences
+            assert wrapper.generate_kwargs["temperature"] == 0.7
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_parameter_validation(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that parameter validation works correctly."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test invalid temperature
+        with pytest.raises(ValueError, match="Temperature must be non-negative"):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"temperature": -0.1},
+            )
+
+        # Test invalid top_p
+        with pytest.raises(ValueError, match="top_p must be between 0 and 1"):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"top_p": 1.5},
+            )
+
+        # Test invalid top_k
+        with pytest.raises(ValueError, match="top_k must be positive"):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"top_k": 0},
+            )
+
+        # Test invalid repetition_penalty
+        with pytest.raises(ValueError, match="repetition_penalty must be positive"):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"repetition_penalty": 0.5},
+            )
+
+        # Test conflicting do_sample and temperature
+        with pytest.raises(
+            ValueError, match="When do_sample=False.*temperature must be 0"
+        ):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"do_sample": False, "temperature": 0.7},
+            )
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_parameter_conflict_errors(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test that parameter conflicts are properly detected."""
+        model = vllm_instance if wrapper_class == vLLMWrapper else transformers_instance
+        tokenizer = model.get_tokenizer() if hasattr(model, "get_tokenizer") else None
+
+        # Test conflicting max_tokens and max_new_tokens
+        with pytest.raises(
+            ValueError, match="Cannot specify both 'max_tokens'.*'max_new_tokens'"
+        ):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"max_tokens": 10, "max_new_tokens": 20},
+            )
+
+        # Test conflicting n and num_return_sequences
+        with pytest.raises(
+            ValueError, match="Cannot specify both 'n'.*'num_return_sequences'"
+        ):
+            wrapper_class(
+                model,
+                tokenizer=tokenizer,
+                input_mode="text",
+                generate=True,
+                generate_kwargs={"n": 1, "num_return_sequences": 2},
+            )
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
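
The validation behaviour exercised by test_parameter_validation and test_parameter_conflict_errors amounts to a set of range and conflict checks. The sketch below is inferred from the ValueError messages the tests expect; the function name is hypothetical and this is not the commit's implementation. The repetition_penalty rule is left out because the tests reject 0.5 despite the "must be positive" wording, so the exact threshold is unclear from the diff alone.

# Hypothetical validator sketched from the ValueError messages the tests expect.
def validate_generate_kwargs(kwargs: dict) -> None:
    """Raise ValueError for conflicting or out-of-range generation parameters."""
    if "max_tokens" in kwargs and "max_new_tokens" in kwargs:
        raise ValueError("Cannot specify both 'max_tokens' and 'max_new_tokens'")
    if "n" in kwargs and "num_return_sequences" in kwargs:
        raise ValueError("Cannot specify both 'n' and 'num_return_sequences'")
    if kwargs.get("temperature", 1.0) < 0:
        raise ValueError("Temperature must be non-negative")
    if not 0.0 <= kwargs.get("top_p", 1.0) <= 1.0:
        raise ValueError("top_p must be between 0 and 1")
    if kwargs.get("top_k", 1) <= 0:
        raise ValueError("top_k must be positive")
    if kwargs.get("do_sample") is False and kwargs.get("temperature", 0.0) != 0.0:
        raise ValueError("When do_sample=False, temperature must be 0 (greedy decoding)")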
