2121import cma
2222import numpy as np
2323import pandas
24+ import ray
2425import torch
2526import tqdm
2627import transformers
4445 BufferedRayEvaluationStrategy ,
4546 SerialEvaluationStrategy ,
4647)
48+ from mergekit .merge import run_merge
4749from mergekit .options import MergeOptions
4850
4951
9395 default = False ,
9496 help = "Allow benchmark tasks as objectives" ,
9597)
98+ @click .option (
99+ "--save-final-model/--no-save-final-model" ,
100+ is_flag = True ,
101+ default = True ,
102+ help = "Save the final merged model" ,
103+ )
104+ @click .option (
105+ "--reshard/--no-reshard" ,
106+ is_flag = True ,
107+ default = True ,
108+ help = "Convert models to single-shard safetensors for faster merge" ,
109+ )
96110def main (
97111 genome_config_path : str ,
98112 max_fevals : int ,
@@ -112,6 +126,8 @@ def main(
112126 wandb_entity : Optional [str ],
113127 task_search_path : List [str ],
114128 allow_benchmark_tasks : bool ,
129+ save_final_model : bool ,
130+ reshard : bool ,
115131):
116132 config = EvolMergeConfiguration .model_validate (
117133 yaml .safe_load (open (genome_config_path , "r" , encoding = "utf-8" ))
@@ -146,21 +162,28 @@ def main(
146162 )
147163
148164 # convert models to single-shard safetensors
149- resharded_models = []
150- resharded_base = None
151- for model in tqdm .tqdm (config .genome .models , desc = "Resharding models" ):
152- resharded_models .append (
153- _reshard_model (
154- model , storage_path , merge_options .lora_merge_cache , trust_remote_code
165+ if reshard :
166+ resharded_models = []
167+ resharded_base = None
168+ for model in tqdm .tqdm (config .genome .models , desc = "Resharding models" ):
169+ resharded_models .append (
170+ _reshard_model (
171+ model ,
172+ storage_path ,
173+ merge_options .lora_merge_cache ,
174+ trust_remote_code ,
175+ )
155176 )
156- )
157- if config .genome .base_model is not None :
158- resharded_base = _reshard_model (
159- config .genome .base_model ,
160- storage_path ,
161- merge_options .lora_merge_cache ,
162- trust_remote_code ,
163- )
177+ if config .genome .base_model is not None :
178+ resharded_base = _reshard_model (
179+ config .genome .base_model ,
180+ storage_path ,
181+ merge_options .lora_merge_cache ,
182+ trust_remote_code ,
183+ )
184+ else :
185+ resharded_models = config .genome .models
186+ resharded_base = config .genome .base_model
164187
165188 genome = ModelGenome (
166189 ModelGenomeDefinition .model_validate (
@@ -289,16 +312,22 @@ def parallel_evaluate(x: List[np.ndarray]) -> List[float]:
289312 )
290313 xbest_cost = es .result .fbest
291314 except KeyboardInterrupt :
292- pass
315+ ray . shutdown ()
293316
294317 print ("!!! OPTIMIZATION COMPLETE !!!" )
295318 print (f"Best cost: { xbest_cost :.4f} " )
296319 print ()
297320
298- best_config = genome .genotype_merge_config (xbest )
321+ # save the best merge configuration using original model references
322+ genome_pretty = ModelGenome (config .genome , trust_remote_code = trust_remote_code )
323+ best_config = genome_pretty .genotype_merge_config (xbest )
299324 print ("Best merge configuration:" )
300325 print (best_config .to_yaml ())
301326
327+ if save_final_model :
328+ print ("Saving final model..." )
329+ run_merge (best_config , os .path .join (storage_path , "final_model" ), merge_options )
330+
302331
303332def _reshard_model (
304333 model : ModelReference , storage_path : str , merge_cache : str , trust_remote_code : bool
@@ -322,6 +351,7 @@ def _reshard_model(
322351 revision = merged .model .revision ,
323352 trust_remote_code = trust_remote_code ,
324353 torch_dtype = torch .bfloat16 ,
354+ cache_dir = os .path .join (storage_path , "transformers_cache" ),
325355 )
326356 model_hf .save_pretrained (
327357 out_path , safe_serialization = True , out_shard_size = 1_000_000_000_000
0 commit comments