Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/sampler/ray/sample.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
swift sample --config sampling.yaml
34 changes: 34 additions & 0 deletions examples/sampler/ray/sampling.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
ray_exp_name: sampling

use_ray: true

model: Qwen/Qwen2.5-VL-3B-Instruct
dataset: tastelikefeet/competition_math#16
num_return_sequences: 2
max_length: 2048
system: "You are a math model, you should **think step by step** carefully, and always consider the basic math principles to avoid making calculating mistakes. Give the final answer wrapped with \\boxed{{}}"
load_args: false
sampler_engine: vllm
max_new_tokens: 768
orm_model: math
prm_model: Qwen/Qwen2.5-Math-PRM-7B
override_exist_file: true
num_sampling_per_gpu_batch_size: 4
top_p: 1.0
temperature: 1.0
prm_threshold: 0.8
output_file: sampling.jsonl

device_groups:
nproc_per_node: 4
sample_group:
device: GPU
ranks: list(range(0, 2))
workers:
- sampler
rm_group:
device: GPU
ranks: list(range(2, 4))
workers:
- prm
- orm
1 change: 1 addition & 0 deletions requirements/framework.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ requests
rouge
safetensors
scipy
omegaconf
sentencepiece
simplejson>=3.3.0
sortedcontainers>=1.5.9
Expand Down
44 changes: 43 additions & 1 deletion swift/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import os
import subprocess
import sys
from typing import Dict, List, Optional
import json
from typing import Dict, List, Optional, Any

from swift.utils import get_logger

Expand Down Expand Up @@ -45,6 +46,44 @@ def get_torchrun_args() -> Optional[List[str]]:
return torchrun_args


def prepare_config_args(argv):
for i in range(0, len(argv[1:]), 2):
arg_name = argv[i]
arg_value = argv[i + 1]
if arg_name == '--config':
Comment on lines +50 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This loop iterates through argv[1:], but accesses elements using argv[i] and argv[i + 1]. This will cause an IndexError when i is the last index in the loop, as argv[i + 1] will be out of bounds. The loop should iterate through the indices of argv directly.

To fix this, iterate through range(1, len(argv), 2) and adjust the indexing accordingly.

Suggested change
for i in range(0, len(argv[1:]), 2):
arg_name = argv[i]
arg_value = argv[i + 1]
if arg_name == '--config':
for i in range(1, len(argv), 2):
arg_name = argv[i]
if i + 1 < len(argv):
arg_value = argv[i + 1]
else:
break # Handle the case where there is no value for the last argument

from omegaconf import OmegaConf, DictConfig
from swift.ray import RayHelper
config = OmegaConf.load(arg_value)

def parse_dict_config(cfg: DictConfig) -> Dict[str, Any]:
result = {}
def _traverse(config: Any, parent_key: str = ""):
if isinstance(config, DictConfig):
for key, value in config.items():
if key == 'device_groups':
result[key] = json.dumps(OmegaConf.to_container(value))
else:
current_path = f"{parent_key}.{key}" if parent_key else key
_traverse(value, current_path)
else:
last_key = parent_key.split('.')[-1] if parent_key else ""
result[last_key] = config

_traverse(cfg)
return result

cfg = parse_dict_config(config)
for key, value in cfg.items():
argv.append(f'--{key}')
if not isinstance(value, str):
value = str(value)
argv.append(value)

argv.pop(i)
argv.pop(i)
break


def _compat_web_ui(argv):
# [compat]
method_name = argv[0]
Expand All @@ -56,11 +95,14 @@ def _compat_web_ui(argv):
def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None:
route_mapping = route_mapping or ROUTE_MAPPING
argv = sys.argv[1:]
if 'local-rank' in argv[0]:
argv = argv[1:]
_compat_web_ui(argv)
method_name = argv[0].replace('_', '-')
argv = argv[1:]
file_path = importlib.util.find_spec(route_mapping[method_name]).origin
torchrun_args = get_torchrun_args()
prepare_config_args(argv)
python_cmd = sys.executable
if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}:
args = [python_cmd, file_path, *argv]
Expand Down
4 changes: 3 additions & 1 deletion swift/llm/argument/base_args/base_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .model_args import ModelArguments
from .quant_args import QuantizeArguments
from .template_args import TemplateArguments
from .ray_args import RayArguments

logger = get_logger()

Expand Down Expand Up @@ -52,7 +53,7 @@ def __post_init__(self: 'BaseArguments'):

@dataclass
class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments,
ModelArguments):
ModelArguments, RayArguments):
"""
BaseArguments class is a dataclass that inherits from multiple argument classes:
GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments, ModelArguments.
Expand Down Expand Up @@ -173,6 +174,7 @@ def __post_init__(self):
QuantizeArguments.__post_init__(self)
TemplateArguments.__post_init__(self)
DataArguments.__post_init__(self)
RayArguments.__post_init__(self)
if self.max_length is None and self.model_info is not None:
self.max_length = self.model_info.max_model_len
if self.packing and self.packing_length is None:
Expand Down
17 changes: 17 additions & 0 deletions swift/llm/argument/base_args/ray_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class RayArguments:
    """Arguments that configure Ray-based execution.

    ``device_groups`` is accepted as a JSON string and is decoded into the
    corresponding Python object in ``__post_init__``.
    """

    # Switch for Ray execution.
    use_ray: bool = False

    # Optional experiment name; not consumed inside this class.
    ray_exp_name: Optional[str] = None

    # JSON-encoded device group description; replaced by the decoded object
    # after ``__post_init__`` runs.
    device_groups: Optional[str] = None

    def __post_init__(self):
        """Decode ``device_groups`` when it was supplied as a JSON string.

        Non-string values (including ``None``) are left untouched.

        Raises:
            json.JSONDecodeError: If ``device_groups`` is a string that is not
                valid JSON. The message names the offending argument.
        """
        if not isinstance(self.device_groups, str):
            return
        try:
            self.device_groups = json.loads(self.device_groups)
        except json.JSONDecodeError as e:
            # Re-raise the same exception type so existing handlers keep
            # working, but say which argument was malformed instead of
            # surfacing a context-free parser error.
            raise json.JSONDecodeError(f'`device_groups` is not valid JSON: {e.msg}', e.doc, e.pos) from e
Comment on lines +16 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The device_groups attribute is loaded as a JSON string, but there's no error handling if the string is malformed. If the string is not a valid JSON, json.loads will throw an exception, crashing the program. Add a try-except block to handle potential json.JSONDecodeError exceptions.

Suggested change
if isinstance(self.device_groups, str):
self.device_groups = json.loads(self.device_groups)
try:
self.device_groups = json.loads(self.device_groups)
except json.JSONDecodeError:
print("Error decoding device_groups JSON string.")
self.device_groups = None # or some default value

4 changes: 4 additions & 0 deletions swift/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from swift.utils import get_logger, parse_args, seed_everything
from .argument import BaseArguments
from .utils import ProcessorMixin
from swift.ray.base import RayHelper

logger = get_logger()

Expand All @@ -18,6 +19,9 @@ class SwiftPipeline(ABC, ProcessorMixin):
def __init__(self, args: Optional[Union[List[str], args_class]] = None):
self.args = self._parse_args(args)
args = self.args
if self.args.use_ray:
from swift.ray import RayHelper
RayHelper.initialize(self.args.device_groups)
if hasattr(args, 'seed'):
seed = args.seed + max(getattr(args, 'rank', -1), 0)
seed_everything(seed)
Expand Down
9 changes: 7 additions & 2 deletions swift/llm/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from swift.llm import SamplingArguments
from swift.plugin import orms, prms
from swift.ray.base import RayHelper
from swift.utils import get_logger

logger = get_logger()
Expand All @@ -17,13 +18,15 @@ def __init__(self, input_args: SamplingArguments):
self.orm_model = None
self._prepare_model_tokenizer()
self._prepare_template()
self._prepare_rm()
self._prepare_prm()
self._prepare_orm()

def _prepare_model_tokenizer(self):
args = self.args
_, self.processor = args.get_model_processor(load_model=False)

def _prepare_rm(self):
@RayHelper.function(group='prm')
def _prepare_prm(self):
if self.args.prm_model is None:
self.prm_model = None
logger.warning('prm_model is None.')
Expand All @@ -33,6 +36,8 @@ def _prepare_rm(self):
from swift.llm import PtEngine
self.prm_model = PtEngine(self.args.prm_model, max_batch_size=64)

@RayHelper.function(group='orm')
def _prepare_orm(self):
if self.args.orm_model is None:
self.orm_model = None
logger.warning('orm_model is None.')
Expand Down
40 changes: 27 additions & 13 deletions swift/llm/sampling/vanilla_sampler.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
from copy import deepcopy
from typing import Any, Dict, List

import json
import numpy as np

from swift.llm import RequestConfig
from swift.llm.sampling.base import Sampler
from swift.llm.template.template_inputs import InferRequest
from swift.ray.base import RayHelper
from swift.utils import get_logger
from .utils import get_messages_md5, get_reward

logger = get_logger()


@RayHelper.worker(group=['sampler', 'prm', 'orm'])
class VanillaSampler(Sampler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prepare_sampler()
self.caches = self.read_cache()
Comment on lines +22 to +23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The prepare_sampler and read_cache methods are called immediately after the __init__ method, but they are also decorated with @RayHelper.function. This means that when running in a Ray worker, these methods will not be executed, as the RayHelper.function decorator prevents execution within the worker. This could lead to unexpected behavior, as the sampler might not be properly initialized in the Ray workers.

Consider removing the @RayHelper.function decorator from these methods, or ensure that the initialization logic is correctly handled within the Ray worker context.

Suggested change
self.prepare_sampler()
self.caches = self.read_cache()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prepare_sampler()
self.caches = self.read_cache()


@RayHelper.function(group='sampler')
def prepare_sampler(self):
if self.args.sampler_engine == 'pt':
from swift.llm import PtEngine
_Engine = PtEngine
Expand All @@ -38,8 +42,8 @@ def __init__(self, *args, **kwargs):
self.infer_engine = _Engine(
self.args.model, model_type=self.args.model_type, template=self.template, **self.args.engine_kwargs)
self.infer_engine.strict = False
self.caches = self.read_cache()

@RayHelper.function(group='sampler')
def read_cache(self):
cache_files = self.args.cache_files
caches = {}
Expand Down Expand Up @@ -82,6 +86,7 @@ def check_row_valid(rows):
assert not row.get('videos') or all([isinstance(video, str) and video for video in row['videos']])
assert not row.get('audios') or all([isinstance(audio, str) and audio for audio in row['audios']])

@RayHelper.function(group='sampler', dispatch=lambda n, i, data: ([{'messages': data['messages'][i * len(data['messages']) // n : (i + 1) * len(data['messages']) // n]}], {}), collect='flatten')
def generate(self, data):
resp_all = []
infer_requests = []
Expand Down Expand Up @@ -141,6 +146,20 @@ def generate(self, data):
_cur += 1
return resp_all

@RayHelper.function(group='orm', execute='first')
def get_orm_score(self, infer_requests, ground_truth):
    """Score ``infer_requests`` with ``self.orm_model`` through ``get_reward``.

    The single ``ground_truth`` is repeated once per request and the
    threshold is fixed at 0.0. Whatever ``get_reward`` returns is handed
    back unchanged.
    """
    num_requests = len(infer_requests)
    repeated_truths = [ground_truth] * num_requests
    reward = get_reward(self.orm_model, infer_requests, ground_truths=repeated_truths, threshold=0.0)
    return reward

@RayHelper.function(group='prm', execute='first')
def get_prm_score(self, infer_requests, ground_truth):
    """Score ``infer_requests`` with ``self.prm_model`` through ``get_reward``.

    The single ``ground_truth`` is repeated once per request and the
    threshold is taken from ``self.args.prm_threshold``. Whatever
    ``get_reward`` returns is handed back unchanged.
    """
    num_requests = len(infer_requests)
    repeated_truths = [ground_truth] * num_requests
    cutoff = self.args.prm_threshold
    reward = get_reward(self.prm_model, infer_requests, ground_truths=repeated_truths, threshold=cutoff)
    return reward

def do_sample(self, data):
generated = []
resp_all = self.generate(data)
Expand All @@ -160,18 +179,13 @@ def do_sample(self, data):
_resps = deepcopy(resps)
_resps['messages'][-1]['content'] = ground_truth
infer_requests.append(_resps)
if self.orm_model is not None:
orm_score, _orm_mask = get_reward(
self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0)
if self.args.orm_model is not None:
orm_score, _orm_mask = self.get_orm_score(infer_requests, ground_truth)
else:
orm_score = np.array([1.0] * len(infer_requests))
_orm_mask = np.array([True] * len(infer_requests))
if self.prm_model is not None:
prm_score, _prm_mask = get_reward(
self.prm_model,
infer_requests,
ground_truths=[ground_truth] * len(infer_requests),
threshold=self.args.prm_threshold)
if self.args.prm_model is not None:
prm_score, _prm_mask = self.get_prm_score(infer_requests, ground_truth)
else:
prm_score = np.array([1.0] * len(infer_requests))
_prm_mask = np.array([True] * len(infer_requests))
Expand Down
1 change: 1 addition & 0 deletions swift/ray/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .base import RayHelper
Loading
Loading