-
Notifications
You must be signed in to change notification settings - Fork 925
[WIP]Support ray #6245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[WIP]Support ray #6245
Changes from 7 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| swift sample --config sampling.yaml |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| ray_exp_name: sampling | ||
|
|
||
| use_ray: true | ||
|
|
||
| model: Qwen/Qwen2.5-VL-3B-Instruct | ||
| dataset: tastelikefeet/competition_math#16 | ||
| num_return_sequences: 2 | ||
| max_length: 2048 | ||
| system: "You are a math model, you should **think step by step** carefully, and always consider the basic math principles to avoid making calculating mistakes. Give the final answer wrapped with \\boxed{{}}" | ||
| load_args: false | ||
| sampler_engine: vllm | ||
| max_new_tokens: 768 | ||
| orm_model: math | ||
| prm_model: Qwen/Qwen2.5-Math-PRM-7B | ||
| override_exist_file: true | ||
| num_sampling_per_gpu_batch_size: 4 | ||
| top_p: 1.0 | ||
| temperature: 1.0 | ||
| prm_threshold: 0.8 | ||
| output_file: sampling.jsonl | ||
|
|
||
| device_groups: | ||
| nproc_per_node: 4 | ||
| sample_group: | ||
| device: GPU | ||
| ranks: list(range(0, 2)) | ||
| workers: | ||
| - sampler | ||
| rm_group: | ||
| device: GPU | ||
| ranks: list(range(2, 4)) | ||
| workers: | ||
| - prm | ||
| - orm |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ requests | |
| rouge | ||
| safetensors | ||
| scipy | ||
| omegaconf | ||
| sentencepiece | ||
| simplejson>=3.3.0 | ||
| sortedcontainers>=1.5.9 | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,17 @@ | ||||||||||||||||
| import json | ||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||
| from typing import Optional | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| @dataclass | ||||||||||||||||
| class RayArguments: | ||||||||||||||||
|
|
||||||||||||||||
| use_ray: bool = False | ||||||||||||||||
|
|
||||||||||||||||
| ray_exp_name: Optional[str] = None | ||||||||||||||||
|
|
||||||||||||||||
| device_groups: Optional[str] = None | ||||||||||||||||
|
|
||||||||||||||||
| def __post_init__(self): | ||||||||||||||||
| if isinstance(self.device_groups, str): | ||||||||||||||||
| self.device_groups = json.loads(self.device_groups) | ||||||||||||||||
|
Comment on lines
+16
to
+17
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,25 +1,29 @@ | ||||||||||||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | ||||||||||||||
| import json | ||||||||||||||
| import os | ||||||||||||||
| from copy import deepcopy | ||||||||||||||
| from typing import Any, Dict, List | ||||||||||||||
|
|
||||||||||||||
| import json | ||||||||||||||
| import numpy as np | ||||||||||||||
|
|
||||||||||||||
| from swift.llm import RequestConfig | ||||||||||||||
| from swift.llm.sampling.base import Sampler | ||||||||||||||
| from swift.llm.template.template_inputs import InferRequest | ||||||||||||||
| from swift.ray.base import RayHelper | ||||||||||||||
| from swift.utils import get_logger | ||||||||||||||
| from .utils import get_messages_md5, get_reward | ||||||||||||||
|
|
||||||||||||||
| logger = get_logger() | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @RayHelper.worker(group=['sampler', 'prm', 'orm']) | ||||||||||||||
| class VanillaSampler(Sampler): | ||||||||||||||
|
|
||||||||||||||
| def __init__(self, *args, **kwargs): | ||||||||||||||
| super().__init__(*args, **kwargs) | ||||||||||||||
| self.prepare_sampler() | ||||||||||||||
| self.caches = self.read_cache() | ||||||||||||||
|
Comment on lines
+22
to
+23
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider removing the
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| @RayHelper.function(group='sampler') | ||||||||||||||
| def prepare_sampler(self): | ||||||||||||||
| if self.args.sampler_engine == 'pt': | ||||||||||||||
| from swift.llm import PtEngine | ||||||||||||||
| _Engine = PtEngine | ||||||||||||||
|
|
@@ -38,8 +42,8 @@ def __init__(self, *args, **kwargs): | |||||||||||||
| self.infer_engine = _Engine( | ||||||||||||||
| self.args.model, model_type=self.args.model_type, template=self.template, **self.args.engine_kwargs) | ||||||||||||||
| self.infer_engine.strict = False | ||||||||||||||
| self.caches = self.read_cache() | ||||||||||||||
|
|
||||||||||||||
| @RayHelper.function(group='sampler') | ||||||||||||||
| def read_cache(self): | ||||||||||||||
| cache_files = self.args.cache_files | ||||||||||||||
| caches = {} | ||||||||||||||
|
|
@@ -82,6 +86,7 @@ def check_row_valid(rows): | |||||||||||||
| assert not row.get('videos') or all([isinstance(video, str) and video for video in row['videos']]) | ||||||||||||||
| assert not row.get('audios') or all([isinstance(audio, str) and audio for audio in row['audios']]) | ||||||||||||||
|
|
||||||||||||||
| @RayHelper.function(group='sampler', dispatch=lambda n, i, data: ([{'messages': data['messages'][i * len(data['messages']) // n : (i + 1) * len(data['messages']) // n]}], {}), collect='flatten') | ||||||||||||||
| def generate(self, data): | ||||||||||||||
| resp_all = [] | ||||||||||||||
| infer_requests = [] | ||||||||||||||
|
|
@@ -141,6 +146,20 @@ def generate(self, data): | |||||||||||||
| _cur += 1 | ||||||||||||||
| return resp_all | ||||||||||||||
|
|
||||||||||||||
| @RayHelper.function(group='orm', execute='first') | ||||||||||||||
| def get_orm_score(self, infer_requests, ground_truth): | ||||||||||||||
| return get_reward( | ||||||||||||||
| self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), | ||||||||||||||
| threshold=0.0) | ||||||||||||||
|
|
||||||||||||||
| @RayHelper.function(group='prm', execute='first') | ||||||||||||||
| def get_prm_score(self, infer_requests, ground_truth): | ||||||||||||||
| return get_reward( | ||||||||||||||
| self.prm_model, | ||||||||||||||
| infer_requests, | ||||||||||||||
| ground_truths=[ground_truth] * len(infer_requests), | ||||||||||||||
| threshold=self.args.prm_threshold) | ||||||||||||||
|
|
||||||||||||||
| def do_sample(self, data): | ||||||||||||||
| generated = [] | ||||||||||||||
| resp_all = self.generate(data) | ||||||||||||||
|
|
@@ -160,18 +179,13 @@ def do_sample(self, data): | |||||||||||||
| _resps = deepcopy(resps) | ||||||||||||||
| _resps['messages'][-1]['content'] = ground_truth | ||||||||||||||
| infer_requests.append(_resps) | ||||||||||||||
| if self.orm_model is not None: | ||||||||||||||
| orm_score, _orm_mask = get_reward( | ||||||||||||||
| self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0) | ||||||||||||||
| if self.args.orm_model is not None: | ||||||||||||||
| orm_score, _orm_mask = self.get_orm_score(infer_requests, ground_truth) | ||||||||||||||
| else: | ||||||||||||||
| orm_score = np.array([1.0] * len(infer_requests)) | ||||||||||||||
| _orm_mask = np.array([True] * len(infer_requests)) | ||||||||||||||
| if self.prm_model is not None: | ||||||||||||||
| prm_score, _prm_mask = get_reward( | ||||||||||||||
| self.prm_model, | ||||||||||||||
| infer_requests, | ||||||||||||||
| ground_truths=[ground_truth] * len(infer_requests), | ||||||||||||||
| threshold=self.args.prm_threshold) | ||||||||||||||
| if self.args.prm_model is not None: | ||||||||||||||
| prm_score, _prm_mask = self.get_prm_score(infer_requests, ground_truth) | ||||||||||||||
| else: | ||||||||||||||
| prm_score = np.array([1.0] * len(infer_requests)) | ||||||||||||||
| _prm_mask = np.array([True] * len(infer_requests)) | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .base import RayHelper |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This loop iterates through
`argv[1:]`, but accesses elements using `argv[i]` and `argv[i + 1]`. This will cause an `IndexError` when `i` is the last index in the loop, as `argv[i + 1]` will be out of bounds. The loop should iterate through the indices of `argv` directly.

To fix this, iterate through
`range(1, len(argv), 2)` and adjust the indexing accordingly.