[1/N] ModelOPT PEFT mode support for the megatron-lm #342
Open
jingyu-ml wants to merge 55 commits into main from jingyux/megatron-lora
Commits (55, all authored by jingyu-ml)
2c1abe0  Fixed the CICD for Diffusion
c70c778  Merge branch 'main' into jingyux/fixed-trtexec-cicd
0b18e50  Update req for diffusers
adcc046  Add megatron lora support
2cf979f  Merge branch 'main' into jingyux/megatron-lora
ec3c17b  Merge branch 'main' into jingyux/megatron-lora
e226673  Update
31e8875  Merge branch 'main' into jingyux/megatron-lora
62c8685  Add more functions
e9a78b3  Merge branch 'main' into jingyux/megatron-lora
744eef8  Update: to support quantize the lora layers
206e44f  Update test cases
935f524  Merge branch 'main' into jingyux/megatron-lora
d50bc78  Clean up code
c22e95d  Clean up code, more
c0b955a  more clean up
8eeb8e5  Update more, config + conversation
cf91aba  Update disable/enable logic
a6e19ea  Update restore logic
98f2314  Update sharded axis
381bf4e  Add disable_adapters enable_adpaters support, removed some codes
2617952  Update test cases / update namings
f32d6ed  Add test cases
1b0f424  Update lora implementations
dfd2810  Force to check the mcore model
3823c15  minors on the __init__
7cf4ac5  More fix to the init
3c08dfc  More fix to the init
787f6ff  Update the grad for loras
a43b6c4  Update
9311e04  Merge branch 'main' into jingyux/megatron-lora
1d8ba41  minor on the test case
d4b8a28  Update the codeowners
ffef564  Merge branch 'main' into jingyux/megatron-lora
ce6bead  Remove the import error check
81f8d06  Update init functions
6df1954  Update the comment
1bb3985  Some minor updates
98ef9fb  Update: removed the permodule restore and state
f98711e  Some minor updates
8df12bc  Update the test case and some minor updates
48e9ab5  Update comments for test cases
5318241  Update comments for test cases
49a1e65  Merge branch 'main' into jingyux/megatron-lora
8c31821  Update test case
024d57c  Update on the test case
98e4e73  Merge branch 'main' into jingyux/megatron-lora
22f55e2  Update Changelog
03807e2  Update Changelog
9b96fea  Update test case for quantize / lora
5030b43  Update the grad and some test cases
0b310fb  update init functions
0b202b9  Merge branch 'main' into jingyux/megatron-lora
82dc269  Change name
fad4982  minor
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PEFT API subpackage for torch."""

from . import mode
from .config import *
from .conversion import *
from .convert import *
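For orientation, a minimal sketch of how these re-exports would be consumed; the modelopt.torch.peft package path is an assumption, since the diff view does not show the file name, and the class names come from the config module below.

# Hypothetical usage sketch; package path assumed, not confirmed by this PR page.
from modelopt.torch.peft import ExportPEFTConfig, PEFTAttributeConfig, PEFTConfig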
@@ -0,0 +1,159 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration classes for PEFT methods."""

from collections.abc import Callable

from pydantic import field_validator

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField

__all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]


class PEFTAttributeConfig(ModeloptBaseConfig):
    """Configuration for PEFT adapter attributes."""

    enable: bool = ModeloptField(
        default=True,
        title="Enable adapter",
        description="If True, enables the adapter. If False, bypasses the adapter.",
    )

    rank: int = ModeloptField(
        default=64,
        title="LoRA rank",
        description=(
            "The rank (dimension) of the LoRA matrices. "
            "Higher rank allows more expressiveness but uses more memory."
        ),
    )

    scale: float = ModeloptField(
        default=1.0,
        title="LoRA scaling factor",
        description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.",
    )

    lora_a_init: str = ModeloptField(
        default="kaiming_init",
        title="LoRA A matrix initializer",
        description="Custom initialization function for the LoRA A matrix. Defaults to Kaiming uniform initialization.",
    )

    lora_b_init: str = ModeloptField(
        default="zero_init",
        title="LoRA B matrix initializer",
        description="Custom initialization function for the LoRA B matrix. Defaults to zero initialization.",
    )

    @field_validator("lora_a_init", "lora_b_init")
    @classmethod
    def validate_init_method(cls, v):
        """Validate initialization method is supported."""
        valid_methods = {"kaiming_init", "zero_init"}
        if v not in valid_methods:
            raise ValueError(
                f"Invalid initialization method: {v}. Supported methods: {', '.join(valid_methods)}"
            )
        return v

    @field_validator("rank")
    @classmethod
    def validate_rank(cls, v):
        """Validate rank is positive."""
        if v < 1:
            raise ValueError("rank must be a positive integer")
        return v

    @field_validator("scale")
    @classmethod
    def validate_scale(cls, v):
        """Validate scale is positive."""
        if v <= 0:
            raise ValueError("scale must be a positive number")
        return v


# Type alias for adapter configuration
PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]


class PEFTConfig(ModeloptBaseConfig):
    """Default configuration for ``peft`` mode."""

    adapter_name: str = ModeloptField(
        default="default",
        title="Adapter name",
        description="Name of the adapter to create or update.",
        validate_default=True,
    )

    adapter_cfg: PEFTAdapterCfgType = ModeloptField(
        default={"*": {"rank": 64}},
        title="Adapter configuration",
        description="Configuration for adapters. Maps module patterns to PEFTAttributeConfig or dict.",
        validate_default=True,
    )

    adapter_type: str = ModeloptField(
        default="lora",
        title="Adapter type",
        description="Type of PEFT adapter to use. Currently only 'lora' is supported.",
        validate_default=True,
    )

    freeze_base_model: bool = ModeloptField(
        default=True,
        title="Freeze base weights during training",
        description="Whether to freeze the base model weights; in most cases, this should be set to True.",
        validate_default=True,
    )

    freeze_lora_weights: bool = ModeloptField(
        default=False,
        title="Freeze LoRA weights during training",
        description="Whether to freeze the LoRA adapter weights; in most cases, this should be set to False.",
        validate_default=True,
    )

    @field_validator("adapter_type")
    @classmethod
    def validate_adapter_type(cls, v):
        """Validate adapter type."""
        if v not in ["lora"]:
            raise ValueError(f"Unsupported adapter type: {v}. Only 'lora' is currently supported.")
        return v

    @field_validator("adapter_cfg")
    @classmethod
    def validate_adapter_cfg(cls, v):
        """Validate and convert adapter configurations."""
        validated_cfg = {}
        for key, value in v.items():
            if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig):
                # Convert dict to PEFTAttributeConfig to trigger validation
                try:
                    validated_cfg[key] = PEFTAttributeConfig(**value)
                except Exception as e:
                    raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
            else:
                validated_cfg[key] = value
        return validated_cfg


class ExportPEFTConfig(ModeloptBaseConfig):
    """An empty config."""