
Commit 269eb63

Add --load-in-4bit and --load-in-8bit for HF eval backend (#332)
Allows using bitsandbytes quantization in `mergekit-evolve` when a) not using vLLM and b) not using in-memory mode.
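For example, to evaluate merges at 4-bit precision during an evolutionary run (paths illustrative): `mergekit-evolve ./evol-config.yml --storage-path ./evol-storage --load-in-4bit`. The two flags are mutually exclusive, and combining either with `--vllm` or `--in-memory` raises an error.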
1 parent 84c83f8

File tree

4 files changed: +64 -0 lines changed

mergekit/evo/actors.py

Lines changed: 6 additions & 0 deletions
@@ -62,6 +62,7 @@ def __init__(
         vllm: bool = False,
         batch_size: Optional[int] = None,
         task_manager: Optional[lm_eval.tasks.TaskManager] = None,
+        quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
     ):
         self.config = config
         self.genome = genome
@@ -72,6 +73,7 @@ def __init__(
         self.vllm = vllm
         self.batch_size = batch_size
         self.task_manager = task_manager
+        self.quantization_config = quantization_config
 
         if config.shuffle:
             monkeypatch_lmeval_shuffle()
@@ -105,6 +107,9 @@ def evaluate_genotype(
             logging.error("Model merge failed")
             return {"score": None, "results": None}
 
+        kwargs = {}
+        if self.quantization_config is not None:
+            kwargs["quantization_config"] = self.quantization_config
         logging.info(f"Model merged to {merged_path}")
         return evaluate_model(
             merged_path,
@@ -114,6 +119,7 @@ def evaluate_genotype(
             vllm=self.vllm,
             batch_size=self.batch_size,
             task_manager=self.task_manager,
+            **kwargs,
         )
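The `kwargs` indirection keeps the default path unchanged: an extra argument is forwarded only when a quantization config was actually supplied. A minimal sketch of the idiom (the `evaluate` stub is a stand-in for `evaluate_model`, not mergekit code):

    from typing import Optional

    import transformers

    def evaluate(path: str, **kwargs) -> None:
        # Stand-in that just shows which keyword arguments arrive.
        print(path, kwargs)

    quantization_config: Optional[transformers.BitsAndBytesConfig] = None
    kwargs = {}
    if quantization_config is not None:
        kwargs["quantization_config"] = quantization_config
    evaluate("/tmp/merged", **kwargs)  # empty kwargs when no config is set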

mergekit/evo/helpers.py

Lines changed: 2 additions & 0 deletions
@@ -67,13 +67,15 @@ def evaluate_model(
     vllm: bool,
     batch_size: Optional[int] = None,
     task_manager: Optional[lm_eval.tasks.TaskManager] = None,
+    **kwargs,
 ) -> dict:
     # monkeypatch_tqdm()
     monkeypatch_lmeval_vllm()
     try:
         model_args = {
             "pretrained": merged_path,
             "dtype": "bfloat16",
+            **kwargs,
         }
         if vllm:
             model_args["gpu_memory_utilization"] = 0.8

mergekit/evo/strategy.py

Lines changed: 18 additions & 0 deletions
@@ -25,6 +25,7 @@
 import ray.util.queue
 import ray.util.scheduling_strategies
 import torch
+import transformers
 
 from mergekit.evo.actors import InMemoryMergeEvaluator, OnDiskMergeEvaluator
 from mergekit.evo.config import EvolMergeConfiguration
@@ -43,6 +44,7 @@ def __init__(
         batch_size: Optional[int] = None,
         task_search_path: Union[str, List[str], None] = None,
         model_storage_path: Optional[str] = None,
+        quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
     ):
         self.config = config
         self.genome = genome
@@ -51,6 +53,7 @@ def __init__(
         self.batch_size = batch_size
         self.task_manager = lm_eval.tasks.TaskManager(include_path=task_search_path)
         self.model_storage_path = model_storage_path
+        self.quantization_config = quantization_config
         if self.model_storage_path:
             os.makedirs(self.model_storage_path, exist_ok=True)
 
@@ -91,6 +94,7 @@ def __init__(
                 vllm=vllm,
                 batch_size=self.batch_size,
                 task_manager=self.task_manager,
+                quantization_config=self.quantization_config,
             )
             for _ in range(self.num_gpus)
         ]
@@ -120,6 +124,7 @@ def __init__(
         batch_size: Optional[int] = None,
         task_manager: Optional[lm_eval.tasks.TaskManager] = None,
         model_storage_path: Optional[str] = None,
+        quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
     ):
         self.config = config
         self.genome = genome
@@ -130,6 +135,7 @@ def __init__(
         self.batch_size = batch_size
         self.task_manager = task_manager
         self.model_storage_path = model_storage_path
+        self.quantization_config = quantization_config
         self._shutdown = False
 
     async def evaluate_genotype(self, genotype: np.ndarray):
@@ -159,6 +165,9 @@ async def process_queue(self):
 
             while merged and len(evaluating) < self.num_gpus:
                 future_result, merged_path = merged.pop()
+                kwargs = {}
+                if self.quantization_config is not None:
+                    kwargs["quantization_config"] = self.quantization_config
                 evaluating[
                     evaluate_model_ray.remote(
                         merged_path,
@@ -168,6 +177,7 @@ async def process_queue(self):
                         vllm=self.vllm,
                         batch_size=self.batch_size,
                         task_manager=self.task_manager,
+                        **kwargs,
                     )
                 ] = future_result
 
@@ -222,6 +232,8 @@ def __init__(
             vllm=vllm,
             num_gpus=self.num_gpus,
             task_manager=self.task_manager,
+            batch_size=self.batch_size,
+            quantization_config=self.quantization_config,
         )
         self.actor.process_queue.remote()
 
@@ -242,6 +254,7 @@ def evaluate_genotype_serial(
     vllm: bool = False,
     batch_size: Optional[int] = None,
     task_manager: Optional[lm_eval.tasks.TaskManager] = None,
+    quantization_config: Optional[transformers.BitsAndBytesConfig] = None,
 ):
     pg = ray.util.placement_group([{"CPU": 1, "GPU": 1}], strategy="STRICT_PACK")
     strat = ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy(
@@ -252,6 +265,9 @@ def evaluate_genotype_serial(
     )
     if not merged_path:
         return {"score": None, "results": None}
+    kwargs = {}
+    if quantization_config is not None:
+        kwargs["quantization_config"] = quantization_config
     res = ray.get(
         evaluate_model_ray.options(scheduling_strategy=strat).remote(
             merged_path,
@@ -261,6 +277,7 @@ def evaluate_genotype_serial(
             vllm=vllm,
             batch_size=batch_size,
             task_manager=task_manager,
+            **kwargs,
         )
     )
     ray.util.remove_placement_group(pg)
@@ -292,6 +309,7 @@ def evaluate_genotypes(self, genotypes: List[np.ndarray]) -> List[dict]:
                 vllm=self.vllm,
                 batch_size=self.batch_size,
                 task_manager=self.task_manager,
+                quantization_config=self.quantization_config,
             )
             for x in genotypes
         ]
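Note that the hunk at `@@ -222,6 +232,8 @@` adds `batch_size=self.batch_size` as well: a small drive-by fix that forwards the batch size to the queued evaluation actor alongside the new `quantization_config`.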

mergekit/scripts/evolve.py

Lines changed: 38 additions & 0 deletions
@@ -114,6 +114,18 @@
     default=None,
     help="Maximum time to run the optimization in seconds",
 )
+@click.option(
+    "--load-in-8bit",
+    is_flag=True,
+    default=False,
+    help="Evaluate models at 8-bit precision",
+)
+@click.option(
+    "--load-in-4bit",
+    is_flag=True,
+    default=False,
+    help="Evaluate models at 4-bit precision",
+)
 @click.option(
     "--force-population-size",
     type=int,
@@ -142,6 +154,8 @@ def main(
     save_final_model: bool,
     reshard: bool,
     timeout: Optional[float],
+    load_in_8bit: bool,
+    load_in_4bit: bool,
     force_population_size: Optional[int],
 ):
     config = EvolMergeConfiguration.model_validate(
@@ -150,6 +164,29 @@ def main(
 
     check_for_naughty_config(config, allow=allow_benchmark_tasks)
 
+    if load_in_4bit and load_in_8bit:
+        raise ValueError("Cannot load models in both 4-bit and 8-bit")
+
+    if load_in_4bit or load_in_8bit:
+        if vllm:
+            raise ValueError("Cannot use vLLM with 4-bit or 8-bit models")
+        if in_memory:
+            raise ValueError("Cannot use in-memory mode with 4-bit or 8-bit models")
+        try:
+            import bitsandbytes
+        except ImportError:
+            raise RuntimeError("bitsandbytes is not installed")
+
+        bnb_config = transformers.BitsAndBytesConfig(
+            load_in_8bit=load_in_8bit,
+            load_in_4bit=load_in_4bit,
+            bnb_4bit_compute_dtype="bfloat16",
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True,
+        )
+    else:
+        bnb_config = None
+
     if use_wandb:
         if not wandb:
             raise RuntimeError("wandb is not installed")
@@ -235,6 +272,7 @@ def main(
         model_storage_path=os.path.join(storage_path, "merged"),
         batch_size=batch_size,
         task_search_path=task_search_path,
+        quantization_config=bnb_config,
     )
 
     x0 = genome.initial_genotype(random=config.random_init).view(-1).numpy()
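The constructed config hard-codes sensible 4-bit defaults: NF4 quantization, double quantization, and bfloat16 compute. For reference, a plain-transformers equivalent of the quantized load this enables (model path illustrative; requires `bitsandbytes`, which the check above enforces):

    import transformers

    bnb_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "path/to/merged-model",  # illustrative path
        quantization_config=bnb_config,
    )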
