Skip to content

Commit 8ee638e

Browse files
authored
mergekit-evolve QoL tweaks (#298)
A few quality-of-life tweaks for `mergekit-evolve`: 1. Make flash attention optional. 2. Write out the final merged model instead of just the config (`--no-save-final-model` to disable). 3. Add an option to not reshard input models (`--no-reshard`) for lower disk use at the cost of merge speed.
1 parent b4136ba commit 8ee638e

File tree

7 files changed

+80
-25
lines changed

7 files changed

+80
-25
lines changed

mergekit/evo/actors.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import ray.util.scheduling_strategies
2828
import torch
2929
import transformers
30+
from transformers.utils import is_flash_attn_2_available
3031

3132
try:
3233
import vllm
@@ -39,7 +40,11 @@
3940
from mergekit.evo.config import EvolMergeConfiguration
4041
from mergekit.evo.genome import ModelGenome
4142
from mergekit.evo.helpers import _eval_model, evaluate_model, merge_model
42-
from mergekit.evo.monkeypatch import NoInit, monkeypatch_lmeval_shuffle
43+
from mergekit.evo.monkeypatch import (
44+
NoInit,
45+
monkeypatch_lmeval_shuffle,
46+
monkeypatch_lmeval_vllm,
47+
)
4348
from mergekit.graph import Executor
4449
from mergekit.io.tasks import LoaderCache, ReturnTensor
4550
from mergekit.merge import _model_out_config
@@ -72,6 +77,7 @@ def __init__(
7277
monkeypatch_lmeval_shuffle()
7378

7479
# monkeypatch_tqdm()
80+
monkeypatch_lmeval_vllm()
7581

7682

7783
@ray.remote(num_cpus=1, num_gpus=1.0)
@@ -164,13 +170,18 @@ def _maybe_init_model(self, config: MergeConfiguration):
164170
if not different:
165171
return
166172

173+
model_kwargs = {
174+
"trust_remote_code": self.merge_options.trust_remote_code,
175+
"torch_dtype": torch.bfloat16,
176+
}
177+
if is_flash_attn_2_available():
178+
model_kwargs["attn_implementation"] = "flash_attention_2"
179+
167180
with NoInit():
168181
inner_model = (
169182
transformers.AutoModelForCausalLM.from_config(
170183
cfg_out,
171-
trust_remote_code=self.merge_options.trust_remote_code,
172-
attn_implementation="flash_attention_2",
173-
torch_dtype=torch.bfloat16,
184+
**model_kwargs,
174185
)
175186
.bfloat16()
176187
.cuda()
@@ -203,11 +214,14 @@ def _maybe_init_model(self, config: MergeConfiguration):
203214
max_model_len = 8192
204215
logging.warn(f"Clipping sequence length to {max_model_len}")
205216

217+
mem_util = (
218+
0.7 if self.merge_options.cuda else 0.9
219+
) # reduce memory usage if we're also using cuda for the merge
206220
self.model = lm_eval.models.vllm_causallms.VLLM(
207221
pretrained=tempdir,
208222
batch_size=self.batch_size or "auto",
209223
max_model_len=max_model_len,
210-
gpu_memory_utilization=0.7, # can't do 0.9 because the merge will OOM
224+
gpu_memory_utilization=mem_util,
211225
dtype="bfloat16",
212226
device="cuda",
213227
trust_remote_code=self.merge_options.trust_remote_code,
@@ -279,6 +293,7 @@ def evaluate(self, genotype: torch.Tensor) -> dict:
279293
num_fewshot=self.config.num_fewshot,
280294
limit=self.config.limit,
281295
task_manager=self.task_manager,
296+
batch_size=self.batch_size,
282297
)
283298

284299
def evaluate_genotype(

mergekit/evo/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from mergekit.evo.config import TaskConfiguration
3232
from mergekit.evo.genome import ModelGenome
33+
from mergekit.evo.monkeypatch import monkeypatch_lmeval_vllm
3334
from mergekit.merge import run_merge
3435
from mergekit.options import MergeOptions
3536

@@ -68,6 +69,7 @@ def evaluate_model(
6869
task_manager: Optional[lm_eval.tasks.TaskManager] = None,
6970
) -> float:
7071
# monkeypatch_tqdm()
72+
monkeypatch_lmeval_vllm()
7173
try:
7274
model_args = {
7375
"pretrained": merged_path,

mergekit/evo/monkeypatch.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ def _patch_lm_eval():
100100
mergekit.tokenizer.tqdm = fake_module
101101

102102

103+
def monkeypatch_lmeval_vllm():
104+
# HACK: fix crash on some tasks due to unset AUTO_MODEL_CLASS for vLLM
105+
import lm_eval.models.vllm_causallms
106+
107+
lm_eval.models.vllm_causallms.VLLM.AUTO_MODEL_CLASS = (
108+
transformers.AutoModelForCausalLM
109+
)
110+
111+
103112
class NoInit:
104113
def __enter__(self):
105114
def noop(*args, **kwargs):

mergekit/evo/strategy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
self.merge_options,
9090
model_storage_path=self.model_storage_path,
9191
vllm=vllm,
92+
batch_size=self.batch_size,
9293
task_manager=self.task_manager,
9394
)
9495
for _ in range(self.num_gpus)

mergekit/scripts/evolve.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import cma
2222
import numpy as np
2323
import pandas
24+
import ray
2425
import torch
2526
import tqdm
2627
import transformers
@@ -44,6 +45,7 @@
4445
BufferedRayEvaluationStrategy,
4546
SerialEvaluationStrategy,
4647
)
48+
from mergekit.merge import run_merge
4749
from mergekit.options import MergeOptions
4850

4951

@@ -93,6 +95,18 @@
9395
default=False,
9496
help="Allow benchmark tasks as objectives",
9597
)
98+
@click.option(
99+
"--save-final-model/--no-save-final-model",
100+
is_flag=True,
101+
default=True,
102+
help="Save the final merged model",
103+
)
104+
@click.option(
105+
"--reshard/--no-reshard",
106+
is_flag=True,
107+
default=True,
108+
help="Convert models to single-shard safetensors for faster merge",
109+
)
96110
def main(
97111
genome_config_path: str,
98112
max_fevals: int,
@@ -112,6 +126,8 @@ def main(
112126
wandb_entity: Optional[str],
113127
task_search_path: List[str],
114128
allow_benchmark_tasks: bool,
129+
save_final_model: bool,
130+
reshard: bool,
115131
):
116132
config = EvolMergeConfiguration.model_validate(
117133
yaml.safe_load(open(genome_config_path, "r", encoding="utf-8"))
@@ -146,21 +162,28 @@ def main(
146162
)
147163

148164
# convert models to single-shard safetensors
149-
resharded_models = []
150-
resharded_base = None
151-
for model in tqdm.tqdm(config.genome.models, desc="Resharding models"):
152-
resharded_models.append(
153-
_reshard_model(
154-
model, storage_path, merge_options.lora_merge_cache, trust_remote_code
165+
if reshard:
166+
resharded_models = []
167+
resharded_base = None
168+
for model in tqdm.tqdm(config.genome.models, desc="Resharding models"):
169+
resharded_models.append(
170+
_reshard_model(
171+
model,
172+
storage_path,
173+
merge_options.lora_merge_cache,
174+
trust_remote_code,
175+
)
155176
)
156-
)
157-
if config.genome.base_model is not None:
158-
resharded_base = _reshard_model(
159-
config.genome.base_model,
160-
storage_path,
161-
merge_options.lora_merge_cache,
162-
trust_remote_code,
163-
)
177+
if config.genome.base_model is not None:
178+
resharded_base = _reshard_model(
179+
config.genome.base_model,
180+
storage_path,
181+
merge_options.lora_merge_cache,
182+
trust_remote_code,
183+
)
184+
else:
185+
resharded_models = config.genome.models
186+
resharded_base = config.genome.base_model
164187

165188
genome = ModelGenome(
166189
ModelGenomeDefinition.model_validate(
@@ -289,16 +312,22 @@ def parallel_evaluate(x: List[np.ndarray]) -> List[float]:
289312
)
290313
xbest_cost = es.result.fbest
291314
except KeyboardInterrupt:
292-
pass
315+
ray.shutdown()
293316

294317
print("!!! OPTIMIZATION COMPLETE !!!")
295318
print(f"Best cost: {xbest_cost:.4f}")
296319
print()
297320

298-
best_config = genome.genotype_merge_config(xbest)
321+
# save the best merge configuration using original model references
322+
genome_pretty = ModelGenome(config.genome, trust_remote_code=trust_remote_code)
323+
best_config = genome_pretty.genotype_merge_config(xbest)
299324
print("Best merge configuration:")
300325
print(best_config.to_yaml())
301326

327+
if save_final_model:
328+
print("Saving final model...")
329+
run_merge(best_config, os.path.join(storage_path, "final_model"), merge_options)
330+
302331

303332
def _reshard_model(
304333
model: ModelReference, storage_path: str, merge_cache: str, trust_remote_code: bool
@@ -322,6 +351,7 @@ def _reshard_model(
322351
revision=merged.model.revision,
323352
trust_remote_code=trust_remote_code,
324353
torch_dtype=torch.bfloat16,
354+
cache_dir=os.path.join(storage_path, "transformers_cache"),
325355
)
326356
model_hf.save_pretrained(
327357
out_path, safe_serialization=True, out_shard_size=1_000_000_000_000

mergekit/scripts/extract_lora.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from peft.tuners.lora import QuantLinear
1010
from safetensors.torch import save_file
1111
from tqdm import tqdm
12-
from transformers import AutoConfig, AutoModelForCausalLM
12+
from transformers import AutoModelForCausalLM
1313
from transformers.modeling_utils import PreTrainedModel
1414

1515
from mergekit.card import generate_card_lora
@@ -216,8 +216,6 @@ def main(
216216
base_model_ref = ModelReference.parse(base_model)
217217
finetuned_model_ref = ModelReference.parse(finetuned_model)
218218

219-
base_model_config = AutoConfig.from_pretrained(base_model_ref.model.path)
220-
221219
linear_module_names = get_linear_module_names(base_model_ref.model.path)
222220
finetuned_model_linear_module_names = get_linear_module_names(
223221
finetuned_model_ref.model.path

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dependencies = [
2828
[project.optional-dependencies]
2929
dev = ["black~=24.2.0", "isort~=5.13.2", "pre-commit~=3.6.2"]
3030
test = ["pytest~=8.0.1"]
31-
evolve = ["ray", "cma", "lm_eval", "flash-attn", "wandb"]
31+
evolve = ["ray", "cma", "lm_eval", "wandb"]
3232
vllm = ["vllm==0.3.2", "lm_eval[vllm]"]
3333

3434
[project.urls]

0 commit comments

Comments (0)