 import importlib.util
 
 import os
+import time
 from functools import partial
 
 import pytest
@@ -72,7 +73,11 @@ def vllm_instance():
     assert os.environ.get("VLLM_USE_V1") == "0"
 
     try:
-        model = LLM("Qwen/Qwen2.5-0.5B")
+        model = LLM(
+            "Qwen/Qwen2.5-0.5B",
+            max_num_batched_tokens=32768,  # Match max_model_len
+            max_model_len=32768,
+        )
         tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
         tokenizer.pad_token = tokenizer.eos_token
         return model, tokenizer
@@ -717,6 +722,179 @@ def test_generate_false_without_log_probs(
             return_log_probs=False,
         )
 
+    # ================================================
+    # Batching Tests
+    # ================================================
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_batching(self, wrapper_class, vllm_instance, transformers_instance):
+        from concurrent.futures import ThreadPoolExecutor, wait
+
+        # Handle the case where vLLM is not available
+        if wrapper_class == vLLMWrapper:
+            try:
+                model, tokenizer = vllm_instance
+            except Exception as e:
+                if "vLLM compatibility issue" in str(e):
+                    pytest.skip("vLLM not available due to compatibility issues")
+                raise
+        else:
+            model, tokenizer = transformers_instance
+
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            batch_size=4,
+        )
+        # Create 2 threads and send inputs
+        inputs = [
+            TensorDict(
+                text=Text(prompt=[f"Question {i}?", f"Question {i+2}?"]),
+                batch_size=(2,),
+            )
+            for i in range(2)
+        ]
+        pool = ThreadPoolExecutor(max_workers=2)
+        try:
+            futures = [pool.submit(wrapper, input) for input in inputs]
+            wait(futures)
+        finally:
+            pool.shutdown(wait=False, cancel_futures=True)
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_batching_uneven(self, wrapper_class, vllm_instance, transformers_instance):
+        from concurrent.futures import ThreadPoolExecutor, wait
+
+        if wrapper_class == vLLMWrapper:
+            model, tokenizer = vllm_instance
+        else:
+            model, tokenizer = transformers_instance
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            batch_size=5,
+            batching_timeout=5,  # Increased timeout for CI environments
+        )
+        inputs = [
+            TensorDict(text=Text(prompt=["Question 1?"]), batch_size=(1,)),
+            TensorDict(
+                text=Text(prompt=["Question 2?", "Question 3?", "Question 4?"]),
+                batch_size=(3,),
+            ),
+            TensorDict(
+                text=Text(prompt=["Question 5?", "Question 6?"]), batch_size=(2,)
+            ),
+        ]
+        pool = ThreadPoolExecutor(max_workers=3)
+        try:
+            futures = []
+            for input in inputs:
+                futures.append(pool.submit(wrapper, input))
+                time.sleep(0.05)  # Increased delay for more reliable timing
+
+            # Wait for first two futures with longer timeout
+            wait(futures[:2], timeout=3)
+
+            # Check results with more flexible assertions
+            result0 = futures[0].result()
+            result1 = futures[1].result()
+
+            assert result0["text"].prompt == ["Question 1?"]
+            assert result1["text"].prompt == [
+                "Question 2?",
+                "Question 3?",
+                "Question 4?",
+            ]
+
+            # The third future may or may not be done depending on timing
+            # Wait for it with a reasonable timeout
+            wait(futures[2:], timeout=10)
+            if not futures[2].done():
+                raise RuntimeError("Third future not done")
+            result2 = futures[2].result()
+            assert result2["text"].prompt == ["Question 5?", "Question 6?"]
+        finally:
+            pool.shutdown(wait=False, cancel_futures=True)
+
+    @pytest.mark.parametrize(
+        "wrapper_class",
+        [vLLMWrapper, TransformersWrapperMaxTokens],
+        ids=["vllm", "transformers"],
+    )
+    def test_batching_cleanup(
+        self, wrapper_class, vllm_instance, transformers_instance
+    ):
+        """Test batching cleanup functionality."""
+        if wrapper_class == vLLMWrapper:
+            model, tokenizer = vllm_instance
+        else:
+            model, tokenizer = transformers_instance
+
+        wrapper = wrapper_class(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=True,
+            return_log_probs=True,
+            batch_size=3,
+        )
+
+        # Check initial state
+        state = wrapper.get_batching_state()
+        assert state["batching_enabled"] is True
+        assert state["batch_size"] == 3
+        assert state["queue_size"] == 0
+        assert state["pending_futures"] == 0
+
+        # Add some inputs to the queue
+        input1 = TensorDict(text=Text(prompt=["Test 1"]), batch_size=(1,))
+        input2 = TensorDict(text=Text(prompt=["Test 2"]), batch_size=(1,))
+
+        # Submit inputs (they won't be processed immediately due to batch size)
+        from concurrent.futures import ThreadPoolExecutor
+
+        pool = ThreadPoolExecutor(max_workers=1)
+        try:
+            future1 = pool.submit(wrapper, input1)
+            future2 = pool.submit(wrapper, input2)
+
+            # Check state after adding inputs
+            state = wrapper.get_batching_state()
+            assert state["queue_size"] >= 0  # May be 0 if processed immediately
+            assert state["pending_futures"] >= 0
+
+            # Clean up
+            wrapper.cleanup_batching()
+
+            # Check state after cleanup
+            state = wrapper.get_batching_state()
+            assert state["queue_size"] == 0
+            assert state["pending_futures"] == 0
+
+            # Wait for futures to complete or fail
+            try:
+                future1.result(timeout=5)
+                future2.result(timeout=5)
+            except Exception:
+                # Futures may fail after cleanup, which is expected
+                pass
+        finally:
+            pool.shutdown(wait=False, cancel_futures=True)
+
     # ================================================
     # Batch Size Tests
     # ================================================
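For context, the added tests exercise thread-level batching in the wrappers: concurrent calls appear to be queued until roughly `batch_size` prompts have accumulated (or `batching_timeout` expires), executed as a single batch, and then split so each caller gets back only its own rows. Below is a minimal usage sketch based on the API surface visible in this diff (`batch_size`, `batching_timeout`, `get_batching_state`, `cleanup_batching`); the import paths and the checkpoint name are assumptions for illustration, not the library's documented entry points.

```python
# Hypothetical sketch of driving the batching wrapper from two threads.
# Import paths below are assumed; adjust to wherever vLLMWrapper and Text live.
from concurrent.futures import ThreadPoolExecutor, wait

from tensordict import TensorDict
from torchrl.modules.llm import vLLMWrapper  # assumed location
from torchrl.data.llm import Text            # assumed location
from transformers import AutoTokenizer
from vllm import LLM

model = LLM("Qwen/Qwen2.5-0.5B", max_model_len=32768, max_num_batched_tokens=32768)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

wrapper = vLLMWrapper(
    model,
    tokenizer=tokenizer,
    input_mode="text",
    generate=True,
    return_log_probs=True,
    batch_size=4,        # collect up to 4 prompts before running one batch
    batching_timeout=5,  # flush a partial batch after ~5 s
)

# Two callers with two prompts each: the wrapper can merge them into a single
# 4-prompt batch and hand each caller back its own slice of the results.
inputs = [
    TensorDict(text=Text(prompt=["Question 1?", "Question 2?"]), batch_size=(2,)),
    TensorDict(text=Text(prompt=["Question 3?", "Question 4?"]), batch_size=(2,)),
]
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(wrapper, td) for td in inputs]
    wait(futures)
    results = [f.result() for f in futures]

print(wrapper.get_batching_state())  # queue_size / pending_futures bookkeeping
wrapper.cleanup_batching()           # drop any queued work before shutdown
```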