
Commit 83a7c11

yaoyu-33 and HollowMan6 authored

Peft Bridge (#1766)

Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: Yu Yao <[email protected]>
Signed-off-by: Hollow Man <[email protected]>
Co-authored-by: ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟 <[email protected]>
1 parent 0596c92 · commit 83a7c11

File tree

13 files changed: +2066 −584 lines changed


examples/conversion/stream_adapter_weights.py

Lines changed: 509 additions & 0 deletions
Large diffs are not rendered by default.
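The 509-line example script itself is not rendered above. As a rough sketch of the kind of driver it likely contains, built only on the export_adapter_weights API added later in this commit (the import path, bridge type, and model setup below are assumptions, not taken from this diff):

    from megatron.bridge import AutoBridge  # import path assumed

    def stream_adapters(bridge: "AutoBridge", model) -> None:
        """Print each LoRA adapter tensor the bridge yields (hedged sketch)."""
        # export_adapter_weights streams (param_name, weight_tensor) tuples
        # without merging the adapters into the base weights.
        for name, tensor in bridge.export_adapter_weights(model, cpu=True, show_progress=True):
            print(name, tuple(tensor.shape))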

src/megatron/bridge/models/conversion/auto_bridge.py

Lines changed: 54 additions & 3 deletions
@@ -331,6 +331,7 @@ def export_hf_weights(
         cpu: bool = False,
         show_progress: bool = True,
         conversion_tasks: Optional[List[WeightConversionTask]] = None,
+        merge_adapter_weights: bool = True,
     ) -> Iterable["HFWeightTuple"]:
         """
         Export Megatron model weights to HuggingFace format.
@@ -352,6 +353,8 @@ def export_hf_weights(
                 *Please note that this is an advanced feature and should be used with caution.
                 The tasks needs to be built with the `get_conversion_tasks` method first and
                 carefully adjust based on your needs.*
+            merge_adapter_weights: Whether to gather and merge LoRA adapter weights into the base
+                tensors during export (defaults to True). Set to False to export only the base tensors.
 
 
         Yields:
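In effect, export_hf_weights can now produce either a merged or a base-only stream. A minimal sketch, assuming bridge and model are already set up and that model is the first positional argument, as in the sibling methods below:

    # Default: LoRA adapters are gathered and merged into the base tensors.
    merged = bridge.export_hf_weights(model, cpu=True)

    # Base-only export: adapters are left out of the yielded tensors.
    base_only = bridge.export_hf_weights(model, cpu=True, merge_adapter_weights=False)

    for name, tensor in base_only:
        ...  # consume (param_name, weight_tensor) tuples lazily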
@@ -376,6 +379,35 @@ def export_hf_weights(
             cpu=cpu,
             show_progress=show_progress,
             conversion_tasks=conversion_tasks,
+            merge_adapter_weights=merge_adapter_weights,
+        )
+
+    def export_adapter_weights(
+        self,
+        model: list[MegatronModelT],
+        cpu: bool = True,
+        show_progress: bool = True,
+    ) -> Iterable["HFWeightTuple"]:
+        """
+        Export only adapter weights from a Megatron model without merging them into base tensors.
+
+        This is useful when you want to save or inspect LoRA adapters independently from the
+        underlying pretrained weights.
+
+        Args:
+            model: Megatron model instance or list of instances
+            cpu: Whether to move tensors to CPU before yielding
+            show_progress: Display progress bar during export
+
+        Yields:
+            HFWeightTuple: Named tuples of (param_name, weight_tensor) for adapter parameters
+        """
+        dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model))
+        return model_bridge.stream_adapter_weights_megatron_to_hf(
+            dispatch_instance,
+            model,
+            cpu=cpu,
+            show_progress=show_progress,
         )
 
     def save_hf_pretrained(
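A short usage sketch for the new method; writing the stream to a safetensors file is illustrative and not part of this commit:

    from safetensors.torch import save_file

    # The yielded HFWeightTuples are (param_name, weight_tensor) pairs,
    # so they can be materialized directly into a state dict.
    adapter_state = dict(bridge.export_adapter_weights(model, cpu=True))

    # Persist the adapters separately from the base checkpoint (illustrative).
    save_file(adapter_state, "adapter_model.safetensors")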
@@ -385,6 +417,7 @@ def save_hf_pretrained(
         show_progress: bool = True,
         source_path: Optional[Union[str, Path]] = None,
         strict: bool = True,
+        merge_adapter_weights: bool = True,
     ) -> None:
         """
         Save a Megatron model in HuggingFace format.
@@ -410,6 +443,7 @@ def save_hf_pretrained(
                 HuggingFace model with custom modeling files needs to be referenced. If not specified,
                 the path will be automatically determined from the HuggingFace configuration.
             strict: Whether to perform strict validation during weight export
+            merge_adapter_weights: Whether to gather/merge LoRA adapter weights into base tensors during export.
 
 
         Example:
@@ -433,10 +467,21 @@ def save_hf_pretrained(
             # No distributed training, save artifacts
             self.hf_pretrained.save_artifacts(path, original_source_path=source_path)
 
-        self.save_hf_weights(model, path, show_progress, strict)
+        self.save_hf_weights(
+            model,
+            path,
+            show_progress,
+            strict,
+            merge_adapter_weights=merge_adapter_weights,
+        )
 
     def save_hf_weights(
-        self, model: list[MegatronModelT], path: str | Path, show_progress: bool = True, strict: bool = True
+        self,
+        model: list[MegatronModelT],
+        path: str | Path,
+        show_progress: bool = True,
+        strict: bool = True,
+        merge_adapter_weights: bool = True,
     ) -> None:
         """
         Save Megatron model weights in HuggingFace safetensors format.
@@ -457,6 +502,7 @@ def save_hf_weights(
             model: Megatron model instance or list of instances
             path: Directory path where weight files will be saved
             show_progress: Display progress bar during export
+            merge_adapter_weights: Whether to gather/merge LoRA adapter weights into base tensors during export.
 
         Raises:
             ValueError: If the state source doesn't support streaming save
@@ -478,7 +524,12 @@ def save_hf_weights(
             dist.barrier()
         dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model))
         generator = model_bridge.stream_weights_megatron_to_hf(
-            dispatch_instance, model, self.hf_pretrained, cpu=True, show_progress=show_progress
+            dispatch_instance,
+            model,
+            self.hf_pretrained,
+            cpu=True,
+            show_progress=show_progress,
+            merge_adapter_weights=merge_adapter_weights,
         )
 
         # Check if the state source is SafeTensorsStateSource for streaming save.
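Taken together, the flag now flows from save_hf_pretrained through save_hf_weights into the Megatron-to-HF streaming call. A hedged sketch of both save modes (the model/path positional arguments are assumed from the surrounding signatures):

    # Default: gather LoRA adapters and merge them into the exported base tensors.
    bridge.save_hf_pretrained(model, "./checkpoint-merged")

    # Opt out of merging, e.g. to pair the base weights with adapters
    # saved separately via export_adapter_weights().
    bridge.save_hf_pretrained(model, "./checkpoint-base-only", merge_adapter_weights=False)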
