diff --git a/ding/utils/data/rlhf_online_dataset.py b/ding/utils/data/rlhf_online_dataset.py index d307f09a32..0386fca534 100644 --- a/ding/utils/data/rlhf_online_dataset.py +++ b/ding/utils/data/rlhf_online_dataset.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union, Callable, Iterable +from typing import Any, Dict, Union, Callable, Iterable, List from tqdm import tqdm from torch.utils.data import Dataset from torch.distributed import get_rank @@ -17,6 +17,7 @@ def __init__( dataset: Iterable[Dict], tokenizer: AutoTokenizer, input_key: str = "input", + extra_input_keys: List[str] = [], apply_chat_template: bool = False, input_template: str = None, ) -> None: @@ -33,18 +34,29 @@ def __init__( super().__init__() self.tokenizer = tokenizer self.input_template = input_template + self.extra_input_keys = extra_input_keys if apply_chat_template: apply_chat_template = self.tokenizer.apply_chat_template self.prompts = [] + for key in extra_input_keys: + setattr(self, key, []) try: rank = get_rank() except ValueError: # not initialized yet, which is the case in unit test rank = 0 for data in tqdm(dataset, desc="Preprocessing data", disable=not rank == 0): - prompt = self._preprocess_data(data, input_template, input_key, apply_chat_template) - self.prompts.append(prompt) + processed_data = self._preprocess_data( + data, input_template, input_key, extra_input_keys, apply_chat_template + ) + self.prompts.append(processed_data['prompt']) + #maybe can be imporved later + for key in extra_input_keys: + getattr(self, key).append(processed_data[key]) + # self.prompts=np.array(self.prompts) + # for key in extra_input_keys: + # setattr(self, key, np.array(getattr(self,key))) def __len__(self) -> int: """ @@ -56,6 +68,7 @@ def __len__(self) -> int: return len(self.prompts) def __getitem__(self, idx: int) -> str: + #can be improved later for list indexing instead of single indexing """ Overview: Get the item at the given index. @@ -64,13 +77,19 @@ def __getitem__(self, idx: int) -> str: Returns: - item (str): The item at the given index. """ - return self.prompts[idx] + # extra inputs: usually image, video, audio, etc. + if self.extra_input_keys: + extra_inputs = {key: getattr(self, key)[idx] for key in self.extra_input_keys} + else: + extra_inputs = {} + return {"prompt": self.prompts[idx], "multi_modal_data": {**extra_inputs}} def _preprocess_data( self, data: Dict[str, Any], input_template: str = None, input_key: str = "input", + extra_input_keys: List[str] = [], apply_chat_template: Union[bool, Callable] = False, ) -> str: """ @@ -86,6 +105,10 @@ def _preprocess_data( Returns: - prompt (str): The formatted prompt. 
""" + if extra_input_keys: + extra_inputs = {key: data[key] for key in extra_input_keys} + else: + extra_inputs = {} if apply_chat_template: chat = data[input_key] if isinstance(chat, str): @@ -96,4 +119,4 @@ def _preprocess_data( prompt = data[input_key] if input_template: prompt = input_template.format(prompt) - return prompt + return {"prompt": prompt, **extra_inputs} diff --git a/ding/utils/data/tests/test_rlhf_online_dataset.py b/ding/utils/data/tests/test_rlhf_online_dataset.py index cba9e7947c..1e12a777dd 100644 --- a/ding/utils/data/tests/test_rlhf_online_dataset.py +++ b/ding/utils/data/tests/test_rlhf_online_dataset.py @@ -1,27 +1,39 @@ import pytest from datasets import load_dataset -from transformers import AutoTokenizer from ding.utils.data import OnlineRLDataset +from transformers import AutoTokenizer +IMG_CONTEXT_TOKEN = '' +IMG_START_TOKEN = '' +IMG_END_TOKEN = '' +IMG_CONTEXT_NUM = 10 # user-defined number of image patches in the context @pytest.fixture def dataset(): # Load the dataset - hf_dataset = load_dataset("cat-searcher/minif2f-lean4")['validation'] + hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split='test') + hf_dataset0 = hf_dataset.map( + lambda x: { + "query": f"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * IMG_CONTEXT_NUM}{IMG_END_TOKEN}\n{x['query']}", + "image": x["image"], + } + ) + # shuffle the dataset + hf_dataset = hf_dataset0.shuffle(seed=42) print(hf_dataset) return hf_dataset @pytest.fixture def tokenizer(): - return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B") + return AutoTokenizer.from_pretrained("OpenGVLab/InternVL2_5-4B") @pytest.mark.unittest def test_onlinerl_dataset_initialization(dataset, tokenizer): # Initialize OnlineRLDataset online_rl_dataset = OnlineRLDataset( - dataset=dataset, tokenizer=tokenizer, input_key="formal_statement", apply_chat_template=True + dataset=dataset, tokenizer=tokenizer, input_key="query", extra_input_keys=["image"], apply_chat_template=True ) # Check if the dataset is initialized correctly assert len(online_rl_dataset) == len(dataset) @@ -31,9 +43,12 @@ def test_onlinerl_dataset_initialization(dataset, tokenizer): def test_onlinerl_dataset_getitem(dataset, tokenizer): # Initialize OnlineRLDataset online_rl_dataset = OnlineRLDataset( - dataset=dataset, tokenizer=tokenizer, input_key="formal_statement", apply_chat_template=True + dataset=dataset, tokenizer=tokenizer, input_key="query", extra_input_keys=["image"], apply_chat_template=True ) # Check if __getitem__ returns the expected formatted prompt item = online_rl_dataset[0] print(item) - assert isinstance(item, str) + assert "prompt" in item + assert "multi_modal_data" in item + assert "image" in item['multi_modal_data'] + assert isinstance(item['prompt'], str) diff --git a/ding/worker/collector/tests/test_vllm_collector.py b/ding/worker/collector/tests/test_vllm_collector.py new file mode 100644 index 0000000000..ca210bdae7 --- /dev/null +++ b/ding/worker/collector/tests/test_vllm_collector.py @@ -0,0 +1,214 @@ +from typing import List, Tuple, Optional +from ding.worker.collector.vllm_collector import HuggingFaceModelGenerator, get_free_gpus +from vllm.assets.image import ImageAsset +from enum import Enum +from datasets import load_dataset +import asyncio +from PIL import Image +import os +import concurrent.futures +import pytest + + +class Modality(Enum): + IMAGE = "image" + TEXT = "text" + VIDEO = "video" + + +def chunk_list(original_list: List, t: int): + # chunk a list into sub_lists + # base length of sublists + base_length = 
len(original_list) // t + # remaind length of some sub_lists + remainder = len(original_list) % t + new_list = [] + index = 0 + for i in range(t): + if i < remainder: + sublist_length = base_length + 1 + else: + sublist_length = base_length + new_list.append(original_list[index:index + sublist_length]) + index += sublist_length + return new_list + + +def get_prompts_qwen(questions: list, modality: Modality) -> Tuple[List[str], Optional[List[int]]]: + if modality == Modality.IMAGE: + placeholder = "<|image_pad|>" + elif modality == Modality.VIDEO: + placeholder = "<|video_pad|>" + else: + msg = f"Modality {modality} is not supported." + raise ValueError(msg) + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) for question in questions + ] + stop_token_ids = None + return prompts, stop_token_ids + + +def get_multi_modal_input(modality: Modality, filenames: list, questions: list) -> dict: + """ + return { + "data": image or video, + "question": question, + } + """ + if modality == Modality.IMAGE: + # Input image and question + ret = {'data': [], 'question': []} + for filename, question in zip(filenames, questions): + if isinstance(filename, str): + image = ImageAsset(filename) \ + .pil_image.convert("RGB") + #img_question = "What is the content of this image?" + elif isinstance(filename, Image.Image): + image = filename + else: + raise ValueError(f"Unsupported type in filenames: {type(filename)}") + img_question = question + ret["data"].append(image) + ret["question"].append(img_question) + else: + msg = f"Modality {modality} is not supported." + raise ValueError(msg) + return ret + + +# -----------------testing single gpu vllm_actor -------------------------------- +async def single_main(model_path: str, gpu: list, temperature: float, modality: str, prompts: list, data: list): + # note that HFModelGenerator has a parameter + # "mm_processor_kwargs" set to align with the settings of Qwen in default + model = HuggingFaceModelGenerator(model_path=model_path, free_gpus=gpu, temperature=temperature) + inputs = [{"prompt": prompt, "multi_modal_data": {modality: data}} for prompt, data in zip(prompts, data)] + # generate responses + response_ret = [] + for in_data in inputs: + responses = await model.generate(prompt=in_data, num_samples=3) + # print response + response_per_prompt = [] + for response, confidence in responses: + response_per_prompt.append(response) + response_ret.append(response_per_prompt) + return response_ret + + +# run main +@pytest.mark.unittest +def test_single_main(): + # set a temperature > 0 to get multiple responses + free_gpus = get_free_gpus() + model_path = 'Qwen/Qwen2-VL-7B' + temperature = 0.5 + questions = [] + img_names = [] + sample_num = 4 + hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split='test') + for i in range(sample_num): + img_names.append(hf_dataset[i]["image"]) + questions.append(hf_dataset[i]["query"]) + assert len(img_names) == len(questions) + modality = Modality.IMAGE + mm_input = get_multi_modal_input(modality, img_names, questions) + data = mm_input["data"] + question = mm_input["question"] + prompts, stop_token_ids = get_prompts_qwen(question, modality) + responses = asyncio.run( + single_main( + model_path=model_path, + gpu=[free_gpus[0]], + temperature=temperature, + modality=modality.value, + prompts=prompts, + data=data + ) + ) + assert len(responses) == len(questions) + + 
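# Illustration of `chunk_list` above (comment only): it splits work so that sub-list
# lengths differ by at most one, and the multi-GPU helpers below rely on it to assign
# prompts and GPUs to collectors. For example, splitting 10 placeholder prompts across
# 3 collectors gives sub-lists of length 4, 3 and 3:
#
#     sub_lists = chunk_list([f"prompt_{i}" for i in range(10)], 3)
#     assert [len(s) for s in sub_lists] == [4, 3, 3]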
+# -----------------testing multi gpu vllm_actor -------------------------------- +async def run_vllm_collector(gpu_list: list, prompts: List, model_path: str, temperature: float) -> List[str]: + # set visible gpu + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_list)) + # get a model on a single gpu + model = HuggingFaceModelGenerator(model_path, free_gpus=gpu_list, temperature=temperature) + + # get responses for each prompt (can be improved later using async generation) + responses_list = [] + for prompt in prompts: + responses = await model.generate(prompt, num_samples=3) + for response in responses: + responses_list.append(response) + #print(f"[GPU {gpu_list}] Response: {response}") + + return responses_list + + +def start_collector(gpu_list: list, prompts: list, model_path: str, temperature: float) -> List[str]: + # event loop in a process + results = asyncio.run(run_vllm_collector(gpu_list, prompts, model_path, temperature)) + return results + + +def multi_main( + prompts: list, model_path: str, free_gpus: List[int], temperature: float, num_per_gpus_collector: int +) -> None: + # decide how many collectors to use + num_collector = len(free_gpus) // num_per_gpus_collector + # decide how many gpus each collector should use + gpus_per_collector = chunk_list(free_gpus, num_collector) + # split input_prompts to collectors equally + prompts_per_gpu = chunk_list(prompts, num_collector) + with concurrent.futures.ProcessPoolExecutor(max_workers=num_collector) as executor: + futures = [] + for gpu_list, prompts_gpu in zip(gpus_per_collector, prompts_per_gpu): + futures.append(executor.submit(start_collector, gpu_list, prompts_gpu, model_path, temperature)) + + # get all results + all_results = [] + for future in concurrent.futures.as_completed(futures): + all_results.append(future.result()) + + return all_results + + +@pytest.mark.unittest +def test_multi_main(): + # get dataset + hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split='test') + img_names = [] + questions = [] + num = 16 + for i in range(num): + img_names.append(hf_dataset[i]["image"]) + questions.append(hf_dataset[i]["query"]) + assert len(img_names) == len(questions) + # get gpus + free_gpus = get_free_gpus() + # set modality + modality = Modality.IMAGE + # get input + mm_input = get_multi_modal_input(modality, img_names, questions) + data = mm_input["data"] + question = mm_input["question"] + # get prompts + prompts, stop_token_ids = get_prompts_qwen(question, modality) + # set necessary parameters + model_path = 'Qwen/Qwen2-VL-7B' + temperature = 0.5 + num_gpus_per_collector = 1 + assert len(free_gpus) >= num_gpus_per_collector + # set inputs + inputs = [{"prompt": prompt, "multi_modal_data": {modality.value: data}} for prompt, data in zip(prompts, data)] + # get results + result = multi_main(inputs, model_path, free_gpus, temperature, num_gpus_per_collector) + # the default num_samples is 3, set where run_vllm_collector calls model.generate + assert len(result) == len(questions) diff --git a/ding/worker/collector/tests/test_vllm_collector_multi_new.py b/ding/worker/collector/tests/test_vllm_collector_multi_new.py new file mode 100644 index 0000000000..0a255d0624 --- /dev/null +++ b/ding/worker/collector/tests/test_vllm_collector_multi_new.py @@ -0,0 +1,161 @@ +from transformers import AutoTokenizer +from typing import List, Tuple, Optional, Any +import os +from easydict import EasyDict +from datasets import load_dataset +from ding.worker.collector.vllm_collector import VllmCollector, get_free_gpus +import copy +import 
concurrent.futures +import pytest + + +def chunk_list(original_list: List, t: int) -> List[List]: + # chunk a list into sub_lists + # base length of sublists + base_length = len(original_list) // t + # remaind length of some sub_lists + remainder = len(original_list) % t + new_list = [] + index = 0 + for i in range(t): + if i < remainder: + sublist_length = base_length + 1 + else: + sublist_length = base_length + new_list.append(original_list[index:index + sublist_length]) + index += sublist_length + return new_list + + +# prepare dataset +IMG_START_TOKEN = '<|vision_start|>' +IMG_END_TOKEN = '<|vision_end|>' +PLACE_HOLDER = '<|image_pad|>' + + +def dataset(num: int = None) -> List: + # Load the dataset + hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split='test') + hf_dataset0 = hf_dataset.map( + lambda x: { + "query": f"{IMG_START_TOKEN}{PLACE_HOLDER}{IMG_END_TOKEN}{x['query']}", + "image": x["image"], + } + ) + # shuffle the dataset + hf_dataset = hf_dataset0.shuffle(seed=42) + if num is None: + return hf_dataset + else: + ret_data = [] + for i in range(0, num): + ret_data.append(hf_dataset[i]) + return ret_data + + +def run_vllm_collector(config: EasyDict) -> List[dict]: + ''' + ret:[ + { + "prompt_i":output([output_text_0,output_text_1,...,]) + } + ] + ''' + # set GPU for current process + gpu_ids = ",".join(map(str, config.free_gpus)) + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids + collector = VllmCollector(config) + #ret=collector.collect(n_samples=2,num_samples_per_prompt=4) + ret = collector.collect(n_samples=config.n_samples, num_samples_per_prompt=config.num_samples_per_prompt) + return ret + + +def start_collector(config: EasyDict): + # collect within the process + # results:a dict, basic form: + #{"prompt_0":[ans_0,ans_1,...,ans_n],"prompt_1":[ans_0,ans_1,...,ans_n],...} + results = run_vllm_collector(config) + return results + + +def multi_vllm_main(tot_dataset, free_gpus: list, config: EasyDict): + ''' + tot_dataset: the total dataset to process + free_gpus: list of total gpus available for the task + config: user defined config about how to do the task + ''' + num_gpu_per_collector = config.num_gpus_per_collector + # how many collector to use + num_collector = len(free_gpus) // num_gpu_per_collector + # list of list, each list contains the gpus the collecor can use + gpu_per_collector = chunk_list(free_gpus, num_collector) + prompts_per_gpu = chunk_list(tot_dataset, num_collector) + with concurrent.futures.ProcessPoolExecutor(max_workers=num_collector) as executor: + futures = [] + for gpu_list, prompts_per_collector in zip(gpu_per_collector, prompts_per_gpu): + config_per_gpu = copy.deepcopy(config) + config_per_gpu.dataset = prompts_per_collector + config_per_gpu.free_gpus = gpu_list + #config_per_gpu.n_samples = len(prompts_per_collector) + config_per_gpu.n_samples = 2 + futures.append(executor.submit(start_collector, config_per_gpu)) + + # collect all results + all_results = [] + for future in concurrent.futures.as_completed(futures): + all_results.append(future.result()) + return all_results + + # # save results + # with open(config.save_path, "w") as f: + # for response in all_results: + # #print(response) + # for prompt in list(response.keys()): + # f.write(f"{prompt}:\n") + # for i, output in enumerate(response[prompt].outputs): + # f.write(f'output_{i}:\n') + # f.write(f"{output.text}\n") + + +@pytest.mark.unittest +def test_multi_vllm(): + test_dataset = dataset(num=16) + free_gpus = get_free_gpus() + config = EasyDict( + # (str) LLM/VLM model path + 
model_path='Qwen/Qwen2-VL-7B', + # (int) Maximum number of tokens to generate per request + max_tokens=4096, + # (float) Temperature for sampling, 0 means greedy decoding + temperature=1.0, + # (dict) Multimodal processor kwargs for vision-language models + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + }, # defaul set to align with Qwen2-VL-7B + # Dataset related configs + # dataset=test_dataset, + # dataset is defined for each gpu respectively + # (str) Key to access the input data in the dataset + input_key='query', + # (bool) Whether to apply a chat template to the input + apply_chat_template=True, + # (str) Template for the input + input_template=None, + # (bool) Whether to shuffle the dataset + shuffle=True, + extra_input_keys=['image'], + # free_gpus is defined for each gpu respectively + # save_path is the file to store the output + save_path="your_path", + # how many gpus a collector can use + num_gpus_per_collector=1, + num_samples_per_prompt=4 + ) + result = multi_vllm_main(test_dataset, free_gpus, config) + collector_num = len(free_gpus) // config.num_gpus_per_collector + assert len(result) == collector_num + for response in result: + prompts = list(response.keys()) + for prompt in prompts: + assert config.num_samples_per_prompt == len(response[prompt].outputs) diff --git a/ding/worker/collector/tests/test_vllm_collector_multigpu.py b/ding/worker/collector/tests/test_vllm_collector_multigpu.py new file mode 100644 index 0000000000..966171f523 --- /dev/null +++ b/ding/worker/collector/tests/test_vllm_collector_multigpu.py @@ -0,0 +1,160 @@ +from typing import List, Tuple, Optional +import os +from vllm.assets.image import ImageAsset +from enum import Enum +from ding.worker.collector.vllm_collector import HuggingFaceModelGenerator, get_free_gpus +from PIL import Image +from datasets import load_dataset +import concurrent.futures +import asyncio +import pytest + + +def chunk_list(original_list: List, t: int): + # chunk a list into sub_lists + # base length of sublists + base_length = len(original_list) // t + # remaind length of some sub_lists + remainder = len(original_list) % t + new_list = [] + index = 0 + for i in range(t): + if i < remainder: + sublist_length = base_length + 1 + else: + sublist_length = base_length + new_list.append(original_list[index:index + sublist_length]) + index += sublist_length + return new_list + + +class Modality(Enum): + IMAGE = "image" + TEXT = "text" + VIDEO = "video" + + +def get_prompts_qwen(questions: list, modality: Modality) -> Tuple[List[str], Optional[List[int]]]: + if modality == Modality.IMAGE: + placeholder = "<|image_pad|>" + elif modality == Modality.VIDEO: + placeholder = "<|video_pad|>" + else: + msg = f"Modality {modality} is not supported." 
+ raise ValueError(msg) + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) for question in questions + ] + stop_token_ids = None + return prompts, stop_token_ids + + +def get_multi_modal_input(modality: Modality, filenames: list, questions: list) -> dict: + """ + return { + "data": image or video, + "question": question, + } + """ + if modality == Modality.IMAGE: + # Input image and question + ret = {'data': [], 'question': []} + for filename, question in zip(filenames, questions): + if isinstance(filename, str): + image = ImageAsset(filename) \ + .pil_image.convert("RGB") + #img_question = "What is the content of this image?" + elif isinstance(filename, Image.Image): + image = filename + else: + raise ValueError(f"Unsupported type in filenames: {type(filename)}") + img_question = question + ret["data"].append(image) + ret["question"].append(img_question) + else: + msg = f"Modality {modality} is not supported." + raise ValueError(msg) + return ret + + +async def run_vllm_collector(gpu_list: list, prompts: List, model_path: str, temperature: float) -> List[str]: + # set visible gpu + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_list)) + # get a model on a single gpu + model = HuggingFaceModelGenerator(model_path, free_gpus=gpu_list, temperature=temperature) + + # get response for each prompts (can be improved later using async generation) + responses_list = [] + for prompt in prompts: + responses = await model.generate(prompt, num_samples=3) + for response in responses: + responses_list.append(response) + #print(f"[GPU {gpu_list}] Response: {response}") + + return responses_list + + +def start_collector(gpu_list: list, prompts: list, model_path: str, temperature: float) -> List[str]: + # event loop in a process + results = asyncio.run(run_vllm_collector(gpu_list, prompts, model_path, temperature)) + return results + + +def main(prompts: list, model_path: str, free_gpus: List[int], temperature: float, num_per_gpus_collector: int) -> None: + # solve how mant collectors to use + num_collector = len(free_gpus) // num_per_gpus_collector + # slove how many gpus a collector should use + gpus_per_collector = chunk_list(free_gpus, num_collector) + # split input_prompts to collectors equally + prompts_per_gpu = chunk_list(prompts, num_collector) + with concurrent.futures.ProcessPoolExecutor(max_workers=num_collector) as executor: + futures = [] + for gpu_list, prompts_gpu in zip(gpus_per_collector, prompts_per_gpu): + futures.append(executor.submit(start_collector, gpu_list, prompts_gpu, model_path, temperature)) + + # get all results + all_results = [] + for future in concurrent.futures.as_completed(futures): + all_results.append(future.result()) + + return all_results + + +@pytest.mark.unittest +def test_main(): + # get dataset + hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split='test') + img_names = [] + questions = [] + num = 16 + for i in range(num): + img_names.append(hf_dataset[i]["image"]) + questions.append(hf_dataset[i]["query"]) + assert len(img_names) == len(questions) + #get gpus + free_gpus = get_free_gpus() + # set modality + modality = Modality.IMAGE + # get input + mm_input = get_multi_modal_input(modality, img_names, questions) + data = mm_input["data"] + question = mm_input["question"] + # get prompts + prompts, stop_token_ids = get_prompts_qwen(question, modality) + # set necessary parameters + 
model_path = 'Qwen/Qwen2-VL-7B' + temperature = 0.5 + num_gpus_per_collector = 1 + assert len(free_gpus) >= num_gpus_per_collector + # set inputs + inputs = [{"prompt": prompt, "multi_modal_data": {modality.value: data}} for prompt, data in zip(prompts, data)] + # get results + result = main(inputs, model_path, free_gpus, temperature, num_gpus_per_collector) + # default num_smaples is 3, can be modified in line 93 + assert len(result) == len(questions) diff --git a/ding/worker/collector/vllm_collector.py b/ding/worker/collector/vllm_collector.py new file mode 100644 index 0000000000..eefe35c33b --- /dev/null +++ b/ding/worker/collector/vllm_collector.py @@ -0,0 +1,438 @@ +from typing import List, Tuple, Optional, Any +import os +import uuid +import asyncio +import numpy as np +from loguru import logger +from easydict import EasyDict +from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, RequestOutput +from transformers import AutoTokenizer + +from ding.utils.data.rlhf_online_dataset import OnlineRLDataset +from ding.utils import SERIAL_COLLECTOR_REGISTRY +from .base_serial_collector import ISerialCollector + + +def get_free_gpus() -> List[int]: + """ + Overview: + Get IDs of GPUs with free memory. + Returns: + - List[int]: The IDs of the free GPUs. + """ + try: + # Get GPU memory usage using nvidia-smi + gpu_stats = os.popen('nvidia-smi --query-gpu=memory.used,memory.total --format=csv,nounits,noheader')\ + .readlines() + free_gpus = [] + + for gpu_id, stats in enumerate(gpu_stats): + mem_used, mem_total = map(int, stats.strip().split(',')) + # Consider GPU as free if less than 5% memory is used + if mem_used / mem_total < 0.05: + free_gpus.append(gpu_id) + + return free_gpus if free_gpus else [0] # Default to GPU 0 if no free GPUs found + except Exception: + logger.warning("Failed to get GPU stats, defaulting to GPU 0") + return [0] + + +class VllmActor: + + def __init__(self, model_path: str, mm_processor_kwargs: dict, free_gpus: list = None) -> None: + """ + Overview: + Initialize the vLLM actor. For more details, please refer to https://docs.vllm.ai/en/stable. + Arguments: + - model_path (str): The path to the language model. + - mm_processor_kwargs(dict): Multimodal processor kwargs for vision-language models + - free_gpus(list): gpus for the model + """ + if free_gpus is None: + self.free_gpus = get_free_gpus() + else: + self.free_gpus = free_gpus + self.num_gpus = len(self.free_gpus) + assert self.num_gpus > 0, "No GPUs found" + # Set CUDA_VISIBLE_DEVICES to use only free GPUs + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, self.free_gpus)) + self.model_path = model_path + self.mm_processor_kwargs = mm_processor_kwargs + self._initialize() + + def _initialize(self) -> None: + """ + Overview: + Initialize the vLLM actor with a series of arguments. + """ + logger.info("Initializing vLLM") + # TODO: Try other options in https://docs.vllm.ai/en/stable/models/engine_args.html#engine-args. + engine_args = AsyncEngineArgs( + model=self.model_path, + tensor_parallel_size=self.num_gpus, + max_num_batched_tokens=8192, + max_model_len=8192, + # enable_chunked_prefill=True, + max_num_seqs=5, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs=self.mm_processor_kwargs, + ) + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + + async def generate(self, prompt, num_samples: int, max_tokens: int, temperature: float = 0) -> RequestOutput: + """ + Overview: + Generate tactics for the current state. 
+ Arguments: + - prompt : The prompt to generate tactics. + - num_samples (int): The number of tactics to generate. + - max_tokens (int): The maximum number of tokens to generate. + - temperature (float): The temperature for the language model, default to 0. + Returns: + - RequestOutput: The generated tactics and their log-probabilities. + """ + sampling_params = SamplingParams( + n=num_samples, + max_tokens=max_tokens, + temperature=temperature, + ) + + # Using async iterator to handle vLLM's generation process + # 1. vLLM's generate method is asynchronous to prevent blocking while waiting for model outputs + # 2. async for allows streaming the generated outputs incrementally instead of waiting for all results + # 3. This approach is particularly suitable for LLM inference which can be time-consuming + # 4. The request_id ensures unique identification for each generation request + async for oup in self.engine.generate( + prompt, sampling_params, request_id=str(uuid.uuid4().hex) + ): + final_output = oup + return final_output + + +class HuggingFaceModelGenerator: + """ + Overview: + A LLM/VLM generator that uses Hugging Face models with vLLM as the backend. + """ + + def __init__( + self, + model_path: str, + free_gpus: list, + max_tokens: int = 1024, + temperature: float = 0, + mm_processor_kwargs: dict = { + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + } + ) -> None: + """ + Overview: + Initialize the Hugging Face model generator. + Arguments: + - model_path (str): The path to the language model. + - max_tokens (int): The maximum number of tokens to generate, default to 1024. + - temperature (float): The temperature for the language model, default to 0. + """ + self.vllm_actor = VllmActor(model_path, mm_processor_kwargs, free_gpus) + self.max_tokens = max_tokens + self.temperature = temperature + + async def generate( + self, + prompt, + num_samples: int, + ) -> List[Tuple[str, float]]: + """ + Overview: + Generate tactics for the current state. + Arguments: + - prompt : The prompt to generate tactics. + - num_samples (int): The number of tactics to generate. + Returns: + - List[Tuple[str, float]]: The generated tactics and their log-probabilities. + + .. note:: + This method is asynchronous and returns a coroutine. + """ + response = await self.vllm_actor.generate(prompt, num_samples, self.max_tokens, self.temperature) + # Use raw logprobs as confidence scores + confidence_scores = [x.cumulative_logprob for x in response.outputs] + return [(x.text.strip(), conf) for x, conf in zip(response.outputs, confidence_scores)] + + +@SERIAL_COLLECTOR_REGISTRY.register('vllm') +class VllmCollector(ISerialCollector): + """ + Overview: + Collector implementation for vLLM-based language models (LLM/VLM). + This collector manages the interaction with vLLM models for text generation tasks. 
+ """ + config = dict( + # (str) LLM/VLM model path + model_path='', + # (int) Maximum number of tokens to generate per request + max_tokens=1024, + # (float) Temperature for sampling, 0 means greedy decoding + temperature=0.0, + # (dict) Multimodal processor kwargs for vision-language models + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + }, + # Dataset related configs + # (str) Key to access the input data in the dataset + input_key='input', + # (bool) Whether to apply a chat template to the input + apply_chat_template=False, + # (str) Template for the input + input_template=None, + # (bool) Whether to shuffle the dataset + shuffle=True, + ) + + def __init__(self, cfg: EasyDict) -> None: + """ + Overview: + Initialize the VllmCollector with configuration. + Arguments: + - cfg (:obj:`EasyDict`): Configuration for the collector including model path, generation parameters, + and dataset configuration + """ + super().__init__() + self._cfg = cfg + self._envstep = 0 + + # Initialize the tokenizer and dataset + self._tokenizer = AutoTokenizer.from_pretrained(cfg.model_path) + self._dataset = OnlineRLDataset( + dataset=cfg.dataset, + tokenizer=self._tokenizer, + input_key=cfg.input_key, + apply_chat_template=cfg.apply_chat_template, + input_template=cfg.input_template, + extra_input_keys=cfg.extra_input_keys + ) + + self._model = VllmActor( + model_path=cfg.model_path, mm_processor_kwargs=cfg.mm_processor_kwargs, free_gpus=cfg.free_gpus + ) + self.reset() + + def reset(self) -> None: + """ + Overview: + Reset the collector, including the dataset index. + """ + self._index = np.arange(len(self._dataset)) + if self._cfg.shuffle: + np.random.shuffle(self._index) + + def reset_policy(self, _model: Optional[str] = None) -> None: + """ + Overview: + Since LLM generation does not require a explicit policy and env, this function is empty. + """ + pass + + def reset_env(self, _env: Optional[Any] = None) -> None: + """ + Overview: + Since LLM generation does not require a explicit policy and env, this function is empty. + """ + pass + + async def _generate_for_prompt(self, prompt: str, num_samples_per_prompt: int) -> List[Tuple[str, float]]: + """ + Overview: + Generate response for the prompt. + Arguments: + - prompt(str) : The prompt to generate tactics. + - num_samples_per_prompt (int): The number of tactics to generate. + Returns: + - List[Tuple[str, float]]: The generated tactics and their log-probabilities. + + """ + return await self._model.generate( + prompt=prompt, + num_samples=num_samples_per_prompt, + max_tokens=self._cfg.max_tokens, + temperature=self._cfg.temperature + ) + + def collect( + self, + n_samples: int = 100, + num_samples_per_prompt: int = 1, + train_iter: int = 0, + ) -> List[Tuple[str, float]]: + """ + Overview: + Collect generated responses from the vLLM model. + Arguments: + - n_samples (:obj:`int`): Number of prompts to generate. + - num_samples_per_prompt (:obj:`int`): Number of samples to generate per prompt. + - train_iter (:obj:`int`): Current training iteration, used for logging. + Returns: + - responses (:obj:`List[Tuple[str, float]]`): List of (generated_text, confidence_score) pairs + """ + if self._model is None: + raise RuntimeError("Model not initialized. 
Call `reset` method first.") + + prompts = [] + for id in self._index[:n_samples]: + prompts.append(self._dataset[id]) + # recusively update the index + self._index = np.concatenate((self._index[n_samples:], self._index[:n_samples])) + + self._envstep += n_samples + + # Get the current event loop or create a new one + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run the async generate method in the event loop + # Create a list of tasks for each prompt + tasks = [self._generate_for_prompt(prompt, num_samples_per_prompt) for prompt in prompts] + + # Run all tasks concurrently and collect results + results = loop.run_until_complete(asyncio.gather(*tasks)) + + # Map prompts to their corresponding results + responses = {prompt["prompt"]: result for prompt, result in zip(prompts, results)} + + return responses + + def sync_collect( + self, + n_samples: int = 100, + num_samples_per_prompt: int = 1, + train_iter: int = 0, + ) -> List[Tuple[str, float]]: + """ + Overview: + Collect generated responses from the vLLM model. + Arguments: + - n_samples (:obj:`int`): Number of prompts to generate. + - num_samples_per_prompt (:obj:`int`): Number of samples to generate per prompt. + - train_iter (:obj:`int`): Current training iteration, used for logging. + Returns: + - responses (:obj:`List[Tuple[str, float]]`): List of (generated_text, confidence_score) pairs + """ + if self._model is None: + raise RuntimeError("Model not initialized. Call `reset` method first.") + + prompts = [] + for id in self._index[:n_samples]: + prompts.append(self._dataset[id]) + # recusively update the index + self._index = np.concatenate((self._index[n_samples:], self._index[:n_samples])) + + self._envstep += n_samples + + # Get the current event loop or create a new one + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run the async generate method in the event loop + results = {} + for prompt in prompts: + # Run the async generate method in the event loop for each prompt + result = loop.run_until_complete( + self._model.generate( + prompt=prompt, + num_samples=num_samples_per_prompt, + max_tokens=self._cfg.max_tokens, + temperature=self._cfg.temperature + ) + ) + results[prompt['prompt']] = result + + return results + + def collect_prompts( + self, + n_samples: int = 100, + num_samples_per_prompt: int = 1, + train_iter: int = 0, + ) -> List[Tuple[str, float]]: + """ + Overview: + Collect generated responses from the vLLM model. + Arguments: + - n_samples (:obj:`int`): Number of prompts to generate. + - num_samples_per_prompt (:obj:`int`): Number of samples to generate per prompt. + - train_iter (:obj:`int`): Current training iteration, used for logging. + Returns: + - responses (:obj:`List[Tuple[str, float]]`): List of (generated_text, confidence_score) pairs + """ + if self._model is None: + raise RuntimeError("Model not initialized. 
Call `reset` method first.") + + prompts = [] + for id in self._index[:n_samples]: + prompts.append(self._dataset[id]) + # recusively update the index + self._index = np.concatenate((self._index[n_samples:], self._index[:n_samples])) + + self._envstep += n_samples + + # Get the current event loop or create a new one + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run the async generate method in the event loop + results = {} + tasks = [] + for prompt in prompts: + for _ in range(num_samples_per_prompt): + # Run the async generate method in the event loop for each prompt + tasks.append(self._generate_for_prompt(prompt, num_samples_per_prompt=1)) + results_list = loop.run_until_complete(asyncio.gather(*tasks)) + for i, prompt in enumerate(prompts): + results[prompt['prompt']] = [] + for result in results_list[i * num_samples_per_prompt:(i + 1) * num_samples_per_prompt]: + results[prompt['prompt']].append(result.outputs[0].text) + return results + + @property + def envstep(self) -> int: + """ + Overview: + Get the current environment step count. + Returns: + - count (:obj:`int`): Current environment step count + """ + return self._envstep + + @envstep.setter + def envstep(self, value: int) -> None: + """ + Overview: + Set the current environment step count. + """ + self._envstep = value + + def close(self) -> None: + """ + Overview: + Close the collector. + """ + pass + + def __del__(self) -> None: + """ + Overview: + Destructor for the collector. + """ + self.close() diff --git a/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py b/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py index 82d6c673ec..a80662941a 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py @@ -63,4 +63,3 @@ from ding.entry import serial_pipeline with DDPContext(): serial_pipeline((main_config, create_config), seed=0) - diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py index 144feac1dd..e3aa855afe 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py index 545ecf970b..440525a320 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py index d48a1fb472..0974735b72 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=17, action_shape=6, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/hopper_medium_expert_iql_config.py b/dizoo/d4rl/config/hopper_medium_expert_iql_config.py index 6aef029c5e..2eebce2771 100644 --- a/dizoo/d4rl/config/hopper_medium_expert_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, action_shape=3, - ), learn=dict( data_path=None, diff --git 
a/dizoo/d4rl/config/hopper_medium_iql_config.py b/dizoo/d4rl/config/hopper_medium_iql_config.py index 8f429be268..61dbb5fac3 100644 --- a/dizoo/d4rl/config/hopper_medium_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, action_shape=3, - ), learn=dict( data_path=None, diff --git a/dizoo/d4rl/config/hopper_medium_replay_iql_config.py b/dizoo/d4rl/config/hopper_medium_replay_iql_config.py index ad1b222843..df96a84aea 100644 --- a/dizoo/d4rl/config/hopper_medium_replay_iql_config.py +++ b/dizoo/d4rl/config/hopper_medium_replay_iql_config.py @@ -18,7 +18,6 @@ model=dict( obs_shape=11, action_shape=3, - ), learn=dict( data_path=None, diff --git a/setup.py b/setup.py index f3d60222f1..3bc8977b46 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,8 @@ 'einops', 'transformers', 'datasets', + 'loguru', + 'vllm' ], extras_require={ 'test': [
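
A minimal usage sketch of the extended OnlineRLDataset, assuming the VL-RewardBench split and the "query"/"image" fields used in the unit tests above; the Qwen2-VL tokenizer is only an assumed stand-in for whichever Hugging Face tokenizer the caller actually needs:

from datasets import load_dataset
from transformers import AutoTokenizer
from ding.utils.data import OnlineRLDataset

hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split="test")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B")  # assumption: any HF tokenizer works here

dataset = OnlineRLDataset(
    dataset=hf_dataset,
    tokenizer=tokenizer,
    input_key="query",            # text prompt field
    extra_input_keys=["image"],   # stored per key and returned under "multi_modal_data"
    apply_chat_template=False,
)

item = dataset[0]
# After this patch __getitem__ returns a dict instead of a bare string:
# {"prompt": <formatted prompt>, "multi_modal_data": {"image": <image from the dataset row>}}
assert set(item) == {"prompt", "multi_modal_data"}
assert "image" in item["multi_modal_data"]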
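
Similarly, a minimal single-process sketch of VllmCollector, assuming at least one free GPU; the model path, sample counts and the small dataset slice follow the tests above rather than being recommended settings, and in the tests the "query" field is additionally wrapped with the Qwen vision tokens before collection:

from datasets import load_dataset
from easydict import EasyDict
from ding.worker.collector.vllm_collector import VllmCollector, get_free_gpus

# A few VL-RewardBench rows as the prompt source (placeholder-sized for illustration).
test_dataset = load_dataset("MMInstruction/VL-RewardBench", split="test").select(range(4))

cfg = EasyDict(
    model_path='Qwen/Qwen2-VL-7B',
    max_tokens=4096,
    temperature=1.0,
    mm_processor_kwargs={"min_pixels": 28 * 28, "max_pixels": 1280 * 28 * 28},
    dataset=test_dataset,
    input_key='query',
    extra_input_keys=['image'],
    apply_chat_template=True,
    input_template=None,
    shuffle=True,
    free_gpus=get_free_gpus()[:1],  # __init__ reads cfg.free_gpus when building the VllmActor
)

collector = VllmCollector(cfg)
# collect() returns {prompt_text: vLLM RequestOutput}; each RequestOutput carries
# num_samples_per_prompt generations in its .outputs list.
responses = collector.collect(n_samples=2, num_samples_per_prompt=4)
for prompt_text, output in responses.items():
    print(prompt_text[:60], len(output.outputs))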