2525import ray .util .queue
2626import ray .util .scheduling_strategies
2727import torch
28+ import transformers
2829
2930from mergekit .evo .actors import InMemoryMergeEvaluator , OnDiskMergeEvaluator
3031from mergekit .evo .config import EvolMergeConfiguration
@@ -43,6 +44,7 @@ def __init__(
4344 batch_size : Optional [int ] = None ,
4445 task_search_path : Union [str , List [str ], None ] = None ,
4546 model_storage_path : Optional [str ] = None ,
47+ quantization_config : Optional [transformers .BitsAndBytesConfig ] = None ,
4648 ):
4749 self .config = config
4850 self .genome = genome
@@ -51,6 +53,7 @@ def __init__(
5153 self .batch_size = batch_size
5254 self .task_manager = lm_eval .tasks .TaskManager (include_path = task_search_path )
5355 self .model_storage_path = model_storage_path
56+ self .quantization_config = quantization_config
5457 if self .model_storage_path :
5558 os .makedirs (self .model_storage_path , exist_ok = True )
5659
@@ -91,6 +94,7 @@ def __init__(
9194 vllm = vllm ,
9295 batch_size = self .batch_size ,
9396 task_manager = self .task_manager ,
97+ quantization_config = self .quantization_config ,
9498 )
9599 for _ in range (self .num_gpus )
96100 ]
@@ -120,6 +124,7 @@ def __init__(
120124 batch_size : Optional [int ] = None ,
121125 task_manager : Optional [lm_eval .tasks .TaskManager ] = None ,
122126 model_storage_path : Optional [str ] = None ,
127+ quantization_config : Optional [transformers .BitsAndBytesConfig ] = None ,
123128 ):
124129 self .config = config
125130 self .genome = genome
@@ -130,6 +135,7 @@ def __init__(
130135 self .batch_size = batch_size
131136 self .task_manager = task_manager
132137 self .model_storage_path = model_storage_path
138+ self .quantization_config = quantization_config
133139 self ._shutdown = False
134140
135141 async def evaluate_genotype (self , genotype : np .ndarray ):
@@ -159,6 +165,9 @@ async def process_queue(self):
159165
160166 while merged and len (evaluating ) < self .num_gpus :
161167 future_result , merged_path = merged .pop ()
168+ kwargs = {}
169+ if self .quantization_config is not None :
170+ kwargs ["quantization_config" ] = self .quantization_config
162171 evaluating [
163172 evaluate_model_ray .remote (
164173 merged_path ,
@@ -168,6 +177,7 @@ async def process_queue(self):
168177 vllm = self .vllm ,
169178 batch_size = self .batch_size ,
170179 task_manager = self .task_manager ,
180+ ** kwargs ,
171181 )
172182 ] = future_result
173183
@@ -222,6 +232,8 @@ def __init__(
222232 vllm = vllm ,
223233 num_gpus = self .num_gpus ,
224234 task_manager = self .task_manager ,
235+ batch_size = self .batch_size ,
236+ quantization_config = self .quantization_config ,
225237 )
226238 self .actor .process_queue .remote ()
227239
@@ -242,6 +254,7 @@ def evaluate_genotype_serial(
242254 vllm : bool = False ,
243255 batch_size : Optional [int ] = None ,
244256 task_manager : Optional [lm_eval .tasks .TaskManager ] = None ,
257+ quantization_config : Optional [transformers .BitsAndBytesConfig ] = None ,
245258):
246259 pg = ray .util .placement_group ([{"CPU" : 1 , "GPU" : 1 }], strategy = "STRICT_PACK" )
247260 strat = ray .util .scheduling_strategies .PlacementGroupSchedulingStrategy (
@@ -252,6 +265,9 @@ def evaluate_genotype_serial(
252265 )
253266 if not merged_path :
254267 return {"score" : None , "results" : None }
268+ kwargs = {}
269+ if quantization_config is not None :
270+ kwargs ["quantization_config" ] = quantization_config
255271 res = ray .get (
256272 evaluate_model_ray .options (scheduling_strategy = strat ).remote (
257273 merged_path ,
@@ -261,6 +277,7 @@ def evaluate_genotype_serial(
261277 vllm = vllm ,
262278 batch_size = batch_size ,
263279 task_manager = task_manager ,
280+ ** kwargs ,
264281 )
265282 )
266283 ray .util .remove_placement_group (pg )
@@ -292,6 +309,7 @@ def evaluate_genotypes(self, genotypes: List[np.ndarray]) -> List[dict]:
292309 vllm = self .vllm ,
293310 batch_size = self .batch_size ,
294311 task_manager = self .task_manager ,
312+ quantization_config = self .quantization_config ,
295313 )
296314 for x in genotypes
297315 ]
0 commit comments