@@ -331,6 +331,7 @@ def export_hf_weights(
331331 cpu : bool = False ,
332332 show_progress : bool = True ,
333333 conversion_tasks : Optional [List [WeightConversionTask ]] = None ,
334+ merge_adapter_weights : bool = True ,
334335 ) -> Iterable ["HFWeightTuple" ]:
335336 """
336337 Export Megatron model weights to HuggingFace format.
@@ -352,6 +353,8 @@ def export_hf_weights(
352353 *Please note that this is an advanced feature and should be used with caution.
353354 The tasks need to be built with the `get_conversion_tasks` method first and
354355 carefully adjusted based on your needs.*
356+ merge_adapter_weights: Whether to gather and merge LoRA adapter weights into the base
357+ tensors during export (defaults to True). Set to False to export only the base tensors.
355358
356359
357360 Yields:
@@ -376,6 +379,35 @@ def export_hf_weights(
376379 cpu = cpu ,
377380 show_progress = show_progress ,
378381 conversion_tasks = conversion_tasks ,
382+ merge_adapter_weights = merge_adapter_weights ,
383+ )
384+
385+ def export_adapter_weights (
386+ self ,
387+ model : list [MegatronModelT ],
388+ cpu : bool = True ,
389+ show_progress : bool = True ,
390+ ) -> Iterable ["HFWeightTuple" ]:
391+ """
392+ Export only adapter weights from a Megatron model without merging them into base tensors.
393+
394+ This is useful when you want to save or inspect LoRA adapters independently from the
395+ underlying pretrained weights.
396+
397+ Args:
398+ model: Megatron model instance or list of instances
399+ cpu: Whether to move tensors to CPU before yielding
400+ show_progress: Display progress bar during export
401+
402+ Yields:
403+ HFWeightTuple: Named tuples of (param_name, weight_tensor) for adapter parameters
404+ """
405+ dispatch_instance = (self ._causal_lm_architecture , self ._get_model_instance (model ))
406+ return model_bridge .stream_adapter_weights_megatron_to_hf (
407+ dispatch_instance ,
408+ model ,
409+ cpu = cpu ,
410+ show_progress = show_progress ,
379411 )
380412
381413 def save_hf_pretrained (
@@ -385,6 +417,7 @@ def save_hf_pretrained(
385417 show_progress : bool = True ,
386418 source_path : Optional [Union [str , Path ]] = None ,
387419 strict : bool = True ,
420+ merge_adapter_weights : bool = True ,
388421 ) -> None :
389422 """
390423 Save a Megatron model in HuggingFace format.
@@ -410,6 +443,7 @@ def save_hf_pretrained(
410443 HuggingFace model with custom modeling files needs to be referenced. If not specified,
411444 the path will be automatically determined from the HuggingFace configuration.
412445 strict: Whether to perform strict validation during weight export
446+ merge_adapter_weights: Whether to gather/merge LoRA adapter weights into base tensors during export.
413447
414448
415449 Example:
@@ -433,10 +467,21 @@ def save_hf_pretrained(
433467 # No distributed training, save artifacts
434468 self .hf_pretrained .save_artifacts (path , original_source_path = source_path )
435469
436- self .save_hf_weights (model , path , show_progress , strict )
470+ self .save_hf_weights (
471+ model ,
472+ path ,
473+ show_progress ,
474+ strict ,
475+ merge_adapter_weights = merge_adapter_weights ,
476+ )
437477
438478 def save_hf_weights (
439- self , model : list [MegatronModelT ], path : str | Path , show_progress : bool = True , strict : bool = True
479+ self ,
480+ model : list [MegatronModelT ],
481+ path : str | Path ,
482+ show_progress : bool = True ,
483+ strict : bool = True ,
484+ merge_adapter_weights : bool = True ,
440485 ) -> None :
441486 """
442487 Save Megatron model weights in HuggingFace safetensors format.
@@ -457,6 +502,7 @@ def save_hf_weights(
457502 model: Megatron model instance or list of instances
458503 path: Directory path where weight files will be saved
459504 show_progress: Display progress bar during export
505+ merge_adapter_weights: Whether to gather/merge LoRA adapter weights into base tensors during export.
460506
461507 Raises:
462508 ValueError: If the state source doesn't support streaming save
@@ -478,7 +524,12 @@ def save_hf_weights(
478524 dist .barrier ()
479525 dispatch_instance = (self ._causal_lm_architecture , self ._get_model_instance (model ))
480526 generator = model_bridge .stream_weights_megatron_to_hf (
481- dispatch_instance , model , self .hf_pretrained , cpu = True , show_progress = show_progress
527+ dispatch_instance ,
528+ model ,
529+ self .hf_pretrained ,
530+ cpu = True ,
531+ show_progress = show_progress ,
532+ merge_adapter_weights = merge_adapter_weights ,
482533 )
483534
484535 # Check if the state source is SafeTensorsStateSource for streaming save.
0 commit comments