diff --git a/examples/sampler/ray/sample.sh b/examples/sampler/ray/sample.sh new file mode 100644 index 0000000000..0031098d93 --- /dev/null +++ b/examples/sampler/ray/sample.sh @@ -0,0 +1 @@ +swift sample --config sampling.yaml \ No newline at end of file diff --git a/examples/sampler/ray/sampling.yaml b/examples/sampler/ray/sampling.yaml new file mode 100644 index 0000000000..2e5658a646 --- /dev/null +++ b/examples/sampler/ray/sampling.yaml @@ -0,0 +1,34 @@ +ray_exp_name: sampling + +use_ray: true + +model: Qwen/Qwen2.5-VL-3B-Instruct +dataset: tastelikefeet/competition_math#16 +num_return_sequences: 2 +max_length: 2048 +system: "You are a math model, you should **think step by step** carefully, and always consider the basic math principles to avoid making calculating mistakes. Give the final answer wrapped with \\boxed{{}}" +load_args: false +sampler_engine: vllm +max_new_tokens: 768 +orm_model: math +prm_model: Qwen/Qwen2.5-Math-PRM-7B +override_exist_file: true +num_sampling_per_gpu_batch_size: 4 +top_p: 1.0 +temperature: 1.0 +prm_threshold: 0.8 +output_file: sampling.jsonl + +device_groups: + nproc_per_node: 4 + sample_group: + device: GPU + ranks: list(range(0, 2)) + workers: + - sampler + rm_group: + device: GPU + ranks: list(range(2, 4)) + workers: + - prm + - orm \ No newline at end of file diff --git a/requirements/framework.txt b/requirements/framework.txt index e6127654e4..2a5d2cbbb5 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -27,6 +27,7 @@ requests rouge safetensors scipy +omegaconf sentencepiece simplejson>=3.3.0 sortedcontainers>=1.5.9 diff --git a/swift/cli/main.py b/swift/cli/main.py index 8924b03bab..3c03b9a875 100644 --- a/swift/cli/main.py +++ b/swift/cli/main.py @@ -3,7 +3,8 @@ import os import subprocess import sys -from typing import Dict, List, Optional +import json +from typing import Dict, List, Optional, Any from swift.utils import get_logger @@ -45,6 +46,44 @@ def get_torchrun_args() -> 
Optional[List[str]]: return torchrun_args +def prepare_config_args(argv): + for i in range(0, len(argv[1:]), 2): + arg_name = argv[i] + arg_value = argv[i + 1] + if arg_name == '--config': + from omegaconf import OmegaConf, DictConfig + from swift.ray import RayHelper + config = OmegaConf.load(arg_value) + + def parse_dict_config(cfg: DictConfig) -> Dict[str, Any]: + result = {} + def _traverse(config: Any, parent_key: str = ""): + if isinstance(config, DictConfig): + for key, value in config.items(): + if key == 'device_groups': + result[key] = json.dumps(OmegaConf.to_container(value)) + else: + current_path = f"{parent_key}.{key}" if parent_key else key + _traverse(value, current_path) + else: + last_key = parent_key.split('.')[-1] if parent_key else "" + result[last_key] = config + + _traverse(cfg) + return result + + cfg = parse_dict_config(config) + for key, value in cfg.items(): + argv.append(f'--{key}') + if not isinstance(value, str): + value = str(value) + argv.append(value) + + argv.pop(i) + argv.pop(i) + break + + def _compat_web_ui(argv): # [compat] method_name = argv[0] @@ -56,11 +95,14 @@ def _compat_web_ui(argv): def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None: route_mapping = route_mapping or ROUTE_MAPPING argv = sys.argv[1:] + if 'local-rank' in argv[0]: + argv = argv[1:] _compat_web_ui(argv) method_name = argv[0].replace('_', '-') argv = argv[1:] file_path = importlib.util.find_spec(route_mapping[method_name]).origin torchrun_args = get_torchrun_args() + prepare_config_args(argv) python_cmd = sys.executable if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}: args = [python_cmd, file_path, *argv] diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index bc18ba573b..fb5a55d1df 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -16,6 +16,7 @@ from .model_args import ModelArguments from .quant_args 
import QuantizeArguments from .template_args import TemplateArguments +from .ray_args import RayArguments logger = get_logger() @@ -52,7 +53,7 @@ def __post_init__(self: 'BaseArguments'): @dataclass class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments, - ModelArguments): + ModelArguments, RayArguments): """ BaseArguments class is a dataclass that inherits from multiple argument classes: GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments, ModelArguments. @@ -173,6 +174,7 @@ def __post_init__(self): QuantizeArguments.__post_init__(self) TemplateArguments.__post_init__(self) DataArguments.__post_init__(self) + RayArguments.__post_init__(self) if self.max_length is None and self.model_info is not None: self.max_length = self.model_info.max_model_len if self.packing and self.packing_length is None: diff --git a/swift/llm/argument/base_args/ray_args.py b/swift/llm/argument/base_args/ray_args.py new file mode 100644 index 0000000000..a84634f596 --- /dev/null +++ b/swift/llm/argument/base_args/ray_args.py @@ -0,0 +1,17 @@ +import json +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + + +@dataclass +class RayArguments: + + use_ray: bool = False + + ray_exp_name: Optional[str] = None + + device_groups: Optional[Union[str, Dict[str, Any]]] = None + + def __post_init__(self): + if isinstance(self.device_groups, str): + self.device_groups = json.loads(self.device_groups) \ No newline at end of file diff --git a/swift/llm/base.py b/swift/llm/base.py index addd19de26..ed4141ee1d 100644 --- a/swift/llm/base.py +++ b/swift/llm/base.py @@ -8,6 +8,7 @@ from swift.utils import get_logger, parse_args, seed_everything from .argument import BaseArguments from .utils import ProcessorMixin +from swift.ray.base import RayHelper logger = get_logger() @@ -18,6 +19,9 @@ class SwiftPipeline(ABC, ProcessorMixin): def __init__(self, args: Optional[Union[List[str], args_class]] = None): self.args = self._parse_args(args) args
= self.args + if self.args.use_ray: + from swift.ray import RayHelper + RayHelper.initialize(self.args.device_groups) if hasattr(args, 'seed'): seed = args.seed + max(getattr(args, 'rank', -1), 0) seed_everything(seed) diff --git a/swift/llm/sampling/base.py b/swift/llm/sampling/base.py index b5967e1234..980176a8ed 100644 --- a/swift/llm/sampling/base.py +++ b/swift/llm/sampling/base.py @@ -2,6 +2,7 @@ from swift.llm import SamplingArguments from swift.plugin import orms, prms +from swift.ray.base import RayHelper from swift.utils import get_logger logger = get_logger() @@ -17,13 +18,15 @@ def __init__(self, input_args: SamplingArguments): self.orm_model = None self._prepare_model_tokenizer() self._prepare_template() - self._prepare_rm() + self._prepare_prm() + self._prepare_orm() def _prepare_model_tokenizer(self): args = self.args _, self.processor = args.get_model_processor(load_model=False) - def _prepare_rm(self): + @RayHelper.function(group='prm') + def _prepare_prm(self): if self.args.prm_model is None: self.prm_model = None logger.warning('prm_model is None.') @@ -33,6 +36,8 @@ def _prepare_rm(self): from swift.llm import PtEngine self.prm_model = PtEngine(self.args.prm_model, max_batch_size=64) + @RayHelper.function(group='orm') + def _prepare_orm(self): if self.args.orm_model is None: self.orm_model = None logger.warning('orm_model is None.') diff --git a/swift/llm/sampling/vanilla_sampler.py b/swift/llm/sampling/vanilla_sampler.py index d8130f7068..f80c3cd624 100644 --- a/swift/llm/sampling/vanilla_sampler.py +++ b/swift/llm/sampling/vanilla_sampler.py @@ -1,25 +1,29 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import json import os from copy import deepcopy -from typing import Any, Dict, List -import json import numpy as np from swift.llm import RequestConfig from swift.llm.sampling.base import Sampler -from swift.llm.template.template_inputs import InferRequest +from swift.ray.base import RayHelper from swift.utils import get_logger from .utils import get_messages_md5, get_reward logger = get_logger() +@RayHelper.worker(group=['sampler', 'prm', 'orm']) class VanillaSampler(Sampler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.prepare_sampler() + self.caches = self.read_cache() + @RayHelper.function(group='sampler') + def prepare_sampler(self): if self.args.sampler_engine == 'pt': from swift.llm import PtEngine _Engine = PtEngine @@ -38,8 +42,8 @@ def __init__(self, *args, **kwargs): self.infer_engine = _Engine( self.args.model, model_type=self.args.model_type, template=self.template, **self.args.engine_kwargs) self.infer_engine.strict = False - self.caches = self.read_cache() + @RayHelper.function(group='sampler') def read_cache(self): cache_files = self.args.cache_files caches = {} @@ -82,6 +86,7 @@ def check_row_valid(rows): assert not row.get('videos') or all([isinstance(video, str) and video for video in row['videos']]) assert not row.get('audios') or all([isinstance(audio, str) and audio for audio in row['audios']]) + @RayHelper.function(group='sampler', dispatch=lambda n, i, data: ([{'messages': data['messages'][i * len(data['messages']) // n : (i + 1) * len(data['messages']) // n]}], {}), collect='flatten') def generate(self, data): resp_all = [] infer_requests = [] @@ -141,6 +146,20 @@ def generate(self, data): _cur += 1 return resp_all + @RayHelper.function(group='orm', execute='first') + def get_orm_score(self, infer_requests, ground_truth): + return get_reward( + self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), + threshold=0.0) + + @RayHelper.function(group='prm', execute='first') + def 
get_prm_score(self, infer_requests, ground_truth): + return get_reward( + self.prm_model, + infer_requests, + ground_truths=[ground_truth] * len(infer_requests), + threshold=self.args.prm_threshold) + def do_sample(self, data): generated = [] resp_all = self.generate(data) @@ -160,18 +179,13 @@ def do_sample(self, data): _resps = deepcopy(resps) _resps['messages'][-1]['content'] = ground_truth infer_requests.append(_resps) - if self.orm_model is not None: - orm_score, _orm_mask = get_reward( - self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0) + if self.args.orm_model is not None: + orm_score, _orm_mask = self.get_orm_score(infer_requests, ground_truth) else: orm_score = np.array([1.0] * len(infer_requests)) _orm_mask = np.array([True] * len(infer_requests)) - if self.prm_model is not None: - prm_score, _prm_mask = get_reward( - self.prm_model, - infer_requests, - ground_truths=[ground_truth] * len(infer_requests), - threshold=self.args.prm_threshold) + if self.args.prm_model is not None: + prm_score, _prm_mask = self.get_prm_score(infer_requests, ground_truth) else: prm_score = np.array([1.0] * len(infer_requests)) _prm_mask = np.array([True] * len(infer_requests)) diff --git a/swift/ray/__init__.py b/swift/ray/__init__.py new file mode 100644 index 0000000000..f1e1995499 --- /dev/null +++ b/swift/ray/__init__.py @@ -0,0 +1 @@ +from .base import RayHelper \ No newline at end of file diff --git a/swift/ray/base.py b/swift/ray/base.py new file mode 100644 index 0000000000..70bba8e108 --- /dev/null +++ b/swift/ray/base.py @@ -0,0 +1,207 @@ +import functools +import os +from typing import Callable, TypeVar, List, Dict, Literal, Union, Any, Type +import ray +import inspect +from ray.runtime_env import RuntimeEnv +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from swift.llm.argument.base_args.ray_args import RayArguments +from swift.ray.resource_manager import ResourceManager +from 
swift.utils import find_free_port +from swift.utils.utils import find_node_ip + +T = TypeVar('T') + + +def is_called_from_init(): + stack = inspect.stack() + for frame_info in stack[1:]: + if frame_info.function == '__init__': + return True + return False + + +class RayHelper: + + resource_manager: ResourceManager = None + + worker_cls: Dict = {} + + args: RayArguments = None + + worker_instance: Dict = {} + + initialized = False + + device_groups: Dict[str, Any] = None + + @staticmethod + def initialize(device_groups: Dict[str, Any]): + RayHelper.device_groups = device_groups + ray.init() + if RayHelper.resource_manager is None: + RayHelper.resource_manager = ResourceManager(device_groups) + RayHelper.initialized = True + + @staticmethod + def worker(group: Union[str, List[str]]): + + is_worker = ray.is_initialized() and ray._private.worker.global_worker.mode == ray._private.worker.WORKER_MODE + + def decorator(cls): + if is_worker: + return cls + cls.decorated = True + groups = [group] if isinstance(group, str) else group + _cls = ray.remote(cls) + for g in groups: + RayHelper.worker_cls[g] = _cls + + init_method = cls.__init__ + + @functools.wraps(init_method) + def new_init(self, *args, **kwargs): + if not is_worker: + RayHelper._create_workers(group, *args, **kwargs) + init_method(self, *args, **kwargs) + + cls.__init__ = new_init + + return cls + + return decorator + + @staticmethod + def collect_func(method: Literal['none', 'flatten']): + if method == 'none': + return lambda x: x + elif method == 'flatten': + return lambda x: [item for sublist in x for item in sublist] + + @staticmethod + def function(group: str, dispatch: Literal['slice', 'all'] = 'all', execute: Literal['first', 'all'] = 'all', collect: Literal['none', 'flatten'] = 'none'): + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + + @functools.wraps(func) + def wrapper(self, *args, **kwargs) -> T: + is_worker = ray.is_initialized() and ray._private.worker.global_worker.mode == 
ray._private.worker.WORKER_MODE + if is_worker: + if not hasattr(self, 'group'): + self.group = os.environ['RAY_SWIFT_GROUP'].split(',') + if group not in self.group: + if is_called_from_init(): + return None + else: + raise ValueError() + else: + return func(self, *args, **kwargs) + else: + if is_called_from_init(): + return None + result = RayHelper.execute_all_sync(group, dispatch, execute, func.__name__, *args, **kwargs) + return RayHelper.collect_func(collect)(result) + return wrapper + + return decorator + + @staticmethod + def execute_all_sync(group, dispatch, execute, method_name: str, *args, **kwargs): + return ray.get(RayHelper.execute_all_async(group, dispatch, execute, method_name, *args, **kwargs)) + + @staticmethod + def execute_all_async(group, dispatch, execute, method_name: str, *args, **kwargs): + workers = RayHelper.worker_instance[group] + length = len(workers) + if execute == 'first': + return getattr(workers[0], method_name).remote(*args, **kwargs) + elif dispatch == 'all': + return [getattr(worker, method_name).remote(*args, **kwargs) for worker in workers] + elif dispatch == 'slice': + result = [] + for i in range(length): + sliced_args = tuple(arg[i] for arg in args) + sliced_kwargs = {k: v[i] for k, v in kwargs.items()} + remote_call = getattr(workers[i], method_name) + result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + else: + result = [] + for i in range(length): + sliced_args, sliced_kwargs = dispatch(length, i, *args, **kwargs) + remote_call = getattr(workers[i], method_name) + result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + return result + + @staticmethod + def _create_workers(group: Union[str, List[str]], *args, **kwargs): + nproc_per_node = int(RayHelper.device_groups['nproc_per_node']) + + if isinstance(group, str): + group = [group] + + for _group in group: + if _group in RayHelper.worker_instance: + continue + + worker_cls = RayHelper.worker_cls[_group] + + _config = None + for name, config in 
RayHelper.device_groups.items(): + if name in RayHelper.resource_manager.possible_keys: + continue + + if _group in config['workers']: + _config = config + break + + assert _config is not None + local_groups = _config['workers'] + world_size = len(_config['ranks']) // nproc_per_node + placement_groups: List[List[Dict]] = RayHelper.resource_manager.resource(_group) + workers = [] + ip, port = None, None + for rank, (pgs, gpu) in enumerate(zip(placement_groups, _config['ranks'])): + deploy_pg = pgs + node_idx = gpu // nproc_per_node + cluster_name = '-'.join(local_groups) + worker_name = cluster_name + '-' + str(rank) + env_vars = os.environ.copy() + env_vars.update({ + "WORLD_SIZE": str(world_size), + "RANK": str(rank), + "LOCAL_RANK": str(0), + "CLUSTER_NAME": cluster_name, + "WORKER_NAME": worker_name, + "CUDA_VISIBLE_DEVICES": str(deploy_pg["gpu_rank"]), + }) + + node_id = RayHelper.resource_manager.nodes[node_idx]['NodeID'] + + @ray.remote + def get_node_address(): + return find_node_ip(), find_free_port() + + if rank == 0: + ip, port = ray.get(get_node_address.remote()) + + env_vars["MASTER_ADDR"] = ip + env_vars["MASTER_PORT"] = str(port) + env_vars["RAY_SWIFT_GROUP"] = ','.join(local_groups) + + runtime_env = RuntimeEnv(env_vars=env_vars) + + worker_options = { + "scheduling_strategy": PlacementGroupSchedulingStrategy(placement_group=deploy_pg["placement_group"]), + "name": worker_name, + "namespace": 'default', + "runtime_env": runtime_env, + "num_cpus": 0.01, + "num_gpus": 0.01, + } + + worker = worker_cls.options(**worker_options).remote(*args, **kwargs) + workers.append(worker) + + for g in local_groups: + RayHelper.worker_instance[g] = workers diff --git a/swift/ray/resource_manager.py b/swift/ray/resource_manager.py new file mode 100644 index 0000000000..5b276b7976 --- /dev/null +++ b/swift/ray/resource_manager.py @@ -0,0 +1,110 @@ +import math +import os +from dataclasses import dataclass, field +from typing import Dict, List, Any + +import ray +from 
ray.util.placement_group import PlacementGroup + +from swift.utils import find_free_port +from swift.utils.utils import find_node_ip + + +@dataclass +class NodeGroup: + device_count: int + nodes: List[Any] = field(default_factory=list) + + +@ray.remote +def get_node_rank(): + return int(os.environ.get("NODE_RANK", "0")) + + +@ray.remote +def get_node_address(): + return find_node_ip(), find_free_port() + + +class ResourceManager: + + possible_keys = ['nproc_per_node', 'nnodes'] + + def __init__(self, groups: Dict[str, Any]): + nproc_per_node = int(groups['nproc_per_node']) + device_types = set([group['device'] for group in groups.values() if hasattr(group, '__getitem__')]) - {'CPU'} + assert len(device_types) == 1 + device_type = next(iter(device_types)) + all_ranks = [] + last_rank = -1 + for group_name, group in groups.items(): + if group_name in self.possible_keys: + continue + ranks = group['ranks'] + device = group['device'] + if device == 'CPU': + continue + try: + ranks = int(ranks) + ranks = list(range(last_rank+1, last_rank+1+ranks)) + except Exception: + if isinstance(ranks, str): + ranks = eval(ranks) + finally: + all_ranks.extend(ranks) + group['ranks'] = ranks + last_rank = ranks[-1] + + assert len(set(all_ranks)) == len(all_ranks) + groups['nnodes'] = math.ceil(len(all_ranks) / nproc_per_node) + + self.nodes = [] + for node in ray.nodes(): + resource = node["Resources"] + node_gpu_num = int(resource.get(device_type, 0)) + if node_gpu_num >= nproc_per_node: + self.nodes.append(node) + + bundles = [] + for i in range(groups['nnodes']): + node = self.nodes[i] + node_cpu = int(node["Resources"]["CPU"]) + bundles.append({device_type: nproc_per_node, "CPU": node_cpu // 2 + 1}) + + self.placement_groups = [ray.util.placement_group([bundle]) for bundle in bundles] + ray.get([pg.ready() for pg in self.placement_groups]) + + self.node_ranks = ray.get( + [get_node_rank.options(placement_group=pg).remote() for pg in self.placement_groups]) + if 
self.node_ranks.count(0) > 1: + self.node_ranks = list(range(len(self.placement_groups))) + + self.node2pg: Dict[int, PlacementGroup] = {} + ip, port = None, None  # NOTE(review): ip/port look unused in this scope — confirm and remove + for node_rank, placement_group in zip(self.node_ranks, self.placement_groups): + self.node2pg[node_rank] = placement_group + + self.device_groups = {} + ray_address = str(ray.get_runtime_context().gcs_address) + for group_name, group in groups.items(): + if group_name in self.possible_keys: + continue + ranks = group['ranks'] + local_device_groups = [] + for rank in ranks: + node_rank = rank // nproc_per_node + gpu_rank = rank % nproc_per_node + local_device_groups.append( + dict(node_rank=node_rank, gpu_rank=gpu_rank, + placement_group=self.node2pg[node_rank], ray_address=ray_address) + ) + for worker in group['workers']: + self.device_groups[worker] = local_device_groups + + self.groups = groups + + def resource(self, worker): + return self.device_groups[worker] + + def destroy_placement_group(self): + for pg in self.placement_groups: ray.util.remove_placement_group(pg) diff --git a/swift/utils/utils.py b/swift/utils/utils.py index c03cedc781..4981421c7f 100644 --- a/swift/utils/utils.py +++ b/swift/utils/utils.py @@ -241,6 +241,12 @@ def get_env_args(args_name: str, type_func: Callable[[str], _T], default_value: return value +def find_node_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + + def find_free_port(start_port: Optional[int] = None, retry: int = 100) -> int: if start_port is None: start_port = 0