
Commit 552dc55

caroteu and anwai98 authored
Fix QLoRA weights and bias initialisation (#833)
Fixes the following:
- Loading of pretrained weights into quantized layers.
- Converting a QLoRA-style finetuned model to a LoRA-style model to run inference.

---------

Co-authored-by: Anwai Archit <[email protected]>
1 parent a47f2f2 commit 552dc55

File tree: 2 files changed, +70 -5 lines changed


micro_sam/models/peft_sam.py

Lines changed: 17 additions & 1 deletion
@@ -342,7 +342,23 @@ def __init__(
                     for sub_module in parent_path:
                         parent_module = getattr(parent_module, sub_module)

-                    setattr(parent_module, layer_name, bnb.nn.Linear4bit(module.in_features, module.out_features))
+                    # Create the new Linear4bit layer
+                    linear_q = bnb.nn.Linear4bit(
+                        module.in_features,
+                        module.out_features,
+                        bias=False if module.bias is None else True,
+                    )
+                    # Assign weights and bias to the new layer
+                    new_weight = bnb.nn.Params4bit(
+                        data=module.weight,
+                        requires_grad=False,
+                    )
+                    linear_q.weight = new_weight
+                    if module.bias is not None:
+                        linear_q.bias = torch.nn.Parameter(module.bias)
+
+                    # Replace the original linear layer with the quantized one
+                    setattr(parent_module, layer_name, linear_q)

         # Let's freeze all the pretrained image encoder layers first
         for param in model.image_encoder.parameters():
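For reference, the replacement pattern introduced above can be tried in isolation. The snippet below is a minimal sketch, not part of the commit: it assumes `bitsandbytes` is installed and a CUDA device is available, and uses an arbitrary 768x768 `torch.nn.Linear` as a stand-in for one of the image encoder's linear layers.

import torch
import bitsandbytes as bnb

# Stand-in for a pretrained linear layer from the image encoder (sizes are illustrative).
module = torch.nn.Linear(768, 768, bias=True)

# Build the quantized replacement with the same shape and bias setting.
linear_q = bnb.nn.Linear4bit(
    module.in_features,
    module.out_features,
    bias=module.bias is not None,
)

# Carry over the pretrained parameters; Params4bit defers quantization until the device transfer.
linear_q.weight = bnb.nn.Params4bit(data=module.weight, requires_grad=False)
if module.bias is not None:
    linear_q.bias = torch.nn.Parameter(module.bias)

# Moving the layer to the GPU triggers the actual 4-bit quantization of the stored weight.
linear_q = linear_q.to("cuda")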

micro_sam/util.py

Lines changed: 53 additions & 4 deletions
@@ -389,17 +389,18 @@ def get_sam_model(
     # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
     # Overwrites the SAM model by freezing the backbone and allow PEFT.
     if peft_kwargs and isinstance(peft_kwargs, dict):
+        # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference.
+        peft_kwargs.pop("quantize", None)
+
         if abbreviated_model_type == "vit_t":
             raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
-
     # In case the model checkpoints have some issues when it is initialized with different parameters than default.
     if flexible_load_checkpoint:
         sam = _handle_checkpoint_loading(sam, model_state)
     else:
         sam.load_state_dict(model_state)
-
     sam.to(device=device)

     predictor = SamPredictor(sam)
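A brief, hypothetical illustration of the guard added in this hunk (the rank value and checkpoint path are assumptions, not from the commit): `peft_kwargs` carried over from a QLoRA training run may still contain `quantize`, and `get_sam_model` now drops that key before wrapping the model with `PEFT_Sam`, since no quantization is applied at inference time.

from micro_sam.util import get_sam_model

# Hypothetical inference-time call; 'quantize' is ignored (popped) inside get_sam_model.
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="finetuned_lora_model.pt",  # illustrative path
    peft_kwargs={"rank": 4, "quantize": True},
)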
@@ -456,13 +457,13 @@ def _handle_checkpoint_loading(sam, model_state):
 def export_custom_sam_model(
     checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike],
 ) -> None:
-    """Export a finetuned segment anything model to the standard model format.
+    """Export a finetuned Segment Anything Model to the standard model format.

     The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.

     Args:
         checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
-        model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
+        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
         save_path: Where to save the exported model.
     """
     _, state = get_sam_model(
@@ -476,6 +477,54 @@ def export_custom_sam_model(
     torch.save(model_state, save_path)


+def export_custom_qlora_model(
+    checkpoint_path: Union[str, os.PathLike],
+    finetuned_path: Union[str, os.PathLike],
+    model_type: str,
+    save_path: Union[str, os.PathLike],
+) -> None:
+    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
+
+    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
+
+    Args:
+        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
+        finetuned_path: The path to the new finetuned model, using QLoRA.
+        model_type: The Segment Anything Model type corresponding to the checkpoint.
+        save_path: Where to save the exported model.
+    """
+    # Step 1: Get the base SAM model: used to start finetuning from.
+    _, sam = get_sam_model(
+        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
+    )
+
+    # Step 2: Load the QLoRA-style finetuned model.
+    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
+
+    # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model.
+    updated_model_state = {}
+
+    # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint.
+    for k, v in ft_model_state.items():
+        if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1:
+            updated_model_state[k] = v
+
+    # - Next, we get all the remaining parameters from the base SAM model.
+    for k, v in sam.state_dict().items():
+        if k.find("attn.qkv.") != -1:
+            k = k.replace("qkv", "qkv.qkv_proj")
+            updated_model_state[k] = v
+        else:
+            updated_model_state[k] = v
+
+    # - Finally, we replace the old model state with the new one (to retain other relevant stuff).
+    ft_state['model_state'] = updated_model_state
+
+    # Step 4: Store the new "state" to "save_path".
+    torch.save(ft_state, save_path)
+
+
 def get_model_names() -> Iterable:
     model_registry = models()
     model_names = model_registry.registry.keys()