@@ -115,7 +115,7 @@ def from_lora_tensors(
         weights_mapper: WeightsMapper | None = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a dictionary of tensors."""
-        pin_memory = str(device) == "cpu" and is_pin_memory_available()
+
         loras: dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
             if is_base_embeddding_weights(tensor_name):
@@ -139,14 +139,8 @@ def from_lora_tensors(
139139 f" with the base model's vocabulary size({ model_vocab_size } )."
140140 )
141141 loras [module_name ].lora_a = tensor .to (device = device , dtype = dtype )
142- if pin_memory :
143- loras [module_name ].lora_a = loras [module_name ].lora_a .pin_memory ()
144142 else :
145143 loras [module_name ].lora_b = tensor .to (device = device , dtype = dtype )
146-
147- if pin_memory :
148- loras [module_name ].lora_b = loras [module_name ].lora_b .pin_memory ()
149-
150144 return cls (lora_model_id , peft_helper .r , loras )
151145
152146 @classmethod
@@ -742,6 +736,32 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
         for lora in lora_model.loras.values():
             lora.optimize()

+        first_lora: LoRALayerWeights = next(iter(lora_model.loras.values()))
+        assert first_lora.lora_a is not None
+        if isinstance(first_lora.lora_a, list):
+            lora_device = next(t.device for t in first_lora.lora_a if t is not None)
+        else:
+            lora_device = first_lora.lora_a.device
+        # Execute pin_memory after LoRA weight merging, mainly because:
+        # 1. Some MoE models have a large number of LoRA weights. If we
+        #    perform pin_memory immediately after loading weights, the
+        #    overhead is significant.
+        # 2. The weight packing above (e.g., pack_moe) may invalidate the
+        #    pin_memory allocation, so we execute it after packing.
+
+        pin_memory = str(lora_device) == "cpu" and is_pin_memory_available()
+        if pin_memory:
+            for lora in lora_model.loras.values():
+                if isinstance(lora.lora_a, list):
+                    for index in range(len(lora.lora_a)):
+                        if lora.lora_a[index] is None:
+                            continue
+                        lora.lora_a[index] = lora.lora_a[index].pin_memory()
+                        lora.lora_b[index] = lora.lora_b[index].pin_memory()
+                else:
+                    lora.lora_a = lora.lora_a.pin_memory()
+                    lora.lora_b = lora.lora_b.pin_memory()
+
     def _get_lora_layer_weights(
         self, lora_model: LoRAModel, module_name: str
     ) -> LoRALayerWeights | None:
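
For context on why the pinning is deferred, the sketch below is illustrative only and not part of this diff: it assumes a CUDA-enabled PyTorch build and uses `torch.stack` as a stand-in for the `pack_moe`-style packing mentioned in the comment. It shows that packing pinned tensors produces a fresh, non-pinned allocation, so pinning only sticks when done after packing, and that pinned host memory is what makes `non_blocking=True` host-to-device copies asynchronous.

```python
# Illustrative sketch only. Assumes a CUDA-enabled PyTorch build; torch.stack
# stands in for the pack_moe-style packing referenced in the diff comment.
import torch

# Per-expert LoRA shards loaded on CPU and pinned immediately (the old scheme).
experts = [torch.randn(16, 16).pin_memory() for _ in range(4)]
assert all(t.is_pinned() for t in experts)

# Packing allocates a brand-new pageable tensor, so the earlier pinning is lost.
packed = torch.stack(experts)
assert not packed.is_pinned()

# Pinning once, after packing, keeps the allocation page-locked...
packed = packed.pin_memory()
assert packed.is_pinned()

# ...which is what allows the host-to-device copy to be truly asynchronous.
if torch.cuda.is_available():
    packed_gpu = packed.to("cuda", non_blocking=True)
```

Pinning also carries noticeable per-call overhead, so issuing one `pin_memory()` per merged/packed tensor rather than one per loaded shard is cheaper as well, which is the first point in the added comment.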