@@ -389,17 +389,18 @@ def get_sam_model(
     # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
     # Overwrites the SAM model by freezing the backbone and allowing PEFT.
     if peft_kwargs and isinstance(peft_kwargs, dict):
+        # NOTE: We pop the 'quantize' parameter, if found, as we do not quantize at inference time.
+        peft_kwargs.pop("quantize", None)
+
         if abbreviated_model_type == "vit_t":
             raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
-
     # In case the model checkpoint has issues when it is initialized with parameters different from the defaults.
     if flexible_load_checkpoint:
         sam = _handle_checkpoint_loading(sam, model_state)
     else:
         sam.load_state_dict(model_state)
-
     sam.to(device=device)

     predictor = SamPredictor(sam)
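
For context, the inference path this hunk touches can be exercised as below. A minimal sketch, assuming `PEFT_Sam` takes a `rank` argument and using a hypothetical checkpoint path; a 'quantize' key carried over from training-time `peft_kwargs` is now popped before the model is wrapped:

from micro_sam.util import get_sam_model

# Load a finetuned model for inference with a PEFT (LoRA) wrapper.
# 'finetuned_qlora_vit_b.pt' is a hypothetical path; the 'rank' key is an assumption.
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="finetuned_qlora_vit_b.pt",
    peft_kwargs={"rank": 4, "quantize": True},  # 'quantize' is now ignored at inference
)
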
@@ -456,13 +457,13 @@ def _handle_checkpoint_loading(sam, model_state):
 def export_custom_sam_model(
     checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike],
 ) -> None:
-    """Export a finetuned segment anything model to the standard model format.
+    """Export a finetuned Segment Anything Model to the standard model format.

     The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.

     Args:
         checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
-        model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
+        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
         save_path: Where to save the exported model.
     """
     _, state = get_sam_model(
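
A minimal usage sketch for this export path (both paths are hypothetical):

from micro_sam.util import export_custom_sam_model

# Convert a finetuning checkpoint into the standard format used by the annotation tools.
export_custom_sam_model(
    checkpoint_path="checkpoints/vit_b/best.pt",  # hypothetical finetuning checkpoint
    model_type="vit_b",
    save_path="vit_b_finetuned.pth",              # hypothetical output path
)
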
@@ -476,6 +477,53 @@ def export_custom_sam_model(
     torch.save(model_state, save_path)


+def export_custom_qlora_model(
+    checkpoint_path: Union[str, os.PathLike],
+    finetuned_path: Union[str, os.PathLike],
+    model_type: str,
+    save_path: Union[str, os.PathLike],
+) -> None:
+    """Export a finetuned Segment Anything Model, in QLoRA style, to the LoRA-style checkpoint format.
+
+    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
+
+    Args:
+        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
+        finetuned_path: The path to the new model finetuned with QLoRA.
+        model_type: The Segment Anything Model type corresponding to the checkpoint.
+        save_path: Where to save the exported model.
+    """
+    # Step 1: Get the base SAM model which finetuning started from.
+    _, sam = get_sam_model(
+        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
+    )
+
+    # Step 2: Load the QLoRA-style finetuned model.
+    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
+
+    # Step 3: Get the LoRA weights from the QLoRA checkpoint and retain all other parameters from the base SAM model.
+    updated_model_state = {}
+
+    # - First, take all LoRA layers from the QLoRA-style finetuned model checkpoint.
+    for k, v in ft_model_state.items():
+        if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1:
+            updated_model_state[k] = v
+
+    # - Next, take all remaining parameters from the base SAM model.
+    for k, v in sam.state_dict().items():
+        if k.find("attn.qkv.") != -1:
+            k = k.replace("qkv", "qkv.qkv_proj")
+            updated_model_state[k] = v
+        else:
+            updated_model_state[k] = v
+
+    # - Finally, replace the old model state with the new one (to retain the other metadata in the checkpoint).
+    ft_state['model_state'] = updated_model_state
+
+    # Step 4: Store the new state at 'save_path'.
+    torch.save(ft_state, save_path)
+
+
 def get_model_names() -> Iterable:
     model_registry = models()
     model_names = model_registry.registry.keys()
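
To illustrate the round trip the new docstring describes, a hedged sketch: export the QLoRA checkpoint to the LoRA-style format, then reload it through the LoRA code path. File names are hypothetical, and the `rank` value is assumed to match the rank used during finetuning:

from micro_sam.util import export_custom_qlora_model, get_sam_model

# Merge the LoRA weights from the QLoRA run with the base SAM parameters.
export_custom_qlora_model(
    checkpoint_path="vit_b_base.pt",       # hypothetical base foundation model
    finetuned_path="vit_b_qlora_best.pt",  # hypothetical QLoRA finetuning result
    model_type="vit_b",
    save_path="vit_b_lora_export.pt",
)

# The exported checkpoint is then consumable as a regular LoRA model.
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="vit_b_lora_export.pt",
    peft_kwargs={"rank": 4},  # assumed: must match the finetuning rank
)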