Merged
2 changes: 1 addition & 1 deletion src/megatron/bridge/models/conversion/model_bridge.py
@@ -446,7 +446,7 @@ def _megatron_global_adapters_info_all_pp_ranks(
if isinstance(adapter, ModuleDict):
adapter_name = local_param_name.removeprefix(local_base_prefix + ".adapter.").split(".")[0]
adapter = adapter[adapter_name]
input_is_parallel, _, _, _, base_linear_is_parallel = get_adapter_attributes_from_linear(to_wrap)
input_is_parallel, _, _, _, _, base_linear_is_parallel = get_adapter_attributes_from_linear(to_wrap)
Contributor

Feels like this should return a dict or dataclass now, since there are many values and they could potentially increase in the future.

I will let this in. If you can, please file another PR; otherwise I will change it next week.
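
For context, the suggested follow-up could bundle the six return values into a small dataclass. The sketch below is only an illustration of that idea; the `AdapterAttributes` name and its field order are assumptions, not anything defined in this PR, which keeps the plain tuple.

```python
# Hypothetical follow-up sketch -- not part of this PR. The merged code still
# returns a plain 6-tuple; `AdapterAttributes` is an invented name.
from dataclasses import dataclass


@dataclass(frozen=True)
class AdapterAttributes:
    input_is_parallel: bool
    in_features: int
    out_features: int
    disable_tensor_parallel_comm: bool
    disable_sequence_parallel_comm: bool
    base_linear_is_parallel: bool


# Call sites could then read fields by name instead of by position, e.g.:
#   attrs = get_adapter_attributes_from_linear(m, is_expert=is_expert)
#   adapter = ParallelLinearAdapter(..., disable_tensor_parallel_comm=attrs.disable_tensor_parallel_comm)
```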

global_param_objects.append(
(
global_base_name,
3 changes: 2 additions & 1 deletion src/megatron/bridge/peft/canonical_lora.py
@@ -226,7 +226,7 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s
)

is_expert = is_expert_linear(full_name)
input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
input_is_parallel, in_features, out_features, disable_tp_comm, disable_sp_comm, base_linear_is_parallel = (
get_adapter_attributes_from_linear(m, is_expert=is_expert)
)

@@ -244,6 +244,7 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s
model_parallel_config=getattr(m, "config", None),
alpha=self.alpha,
is_expert=is_expert,
disable_tensor_parallel_comm=disable_tp_comm,
disable_sequence_parallel_comm=disable_sp_comm,
base_linear_is_parallel=base_linear_is_parallel,
)
3 changes: 2 additions & 1 deletion src/megatron/bridge/peft/dora.py
@@ -90,7 +90,7 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s

if (ans := self.match(m, name, prefix)) is not None:
(match, full_name) = ans
input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
input_is_parallel, in_features, out_features, disable_tp_comm, disable_sp_comm, base_linear_is_parallel = (
get_adapter_attributes_from_linear(m)
)
logger.info(f"Adding DoRA to: {full_name}")
@@ -109,6 +109,7 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s
dropout_position=self.dropout_position,
model_parallel_config=getattr(m, "config", None),
alpha=self.alpha,
disable_tensor_parallel_comm=disable_tp_comm,
disable_sequence_parallel_comm=disable_sp_comm,
base_linear_is_parallel=base_linear_is_parallel,
)
3 changes: 2 additions & 1 deletion src/megatron/bridge/peft/lora.py
@@ -135,7 +135,7 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
)

is_expert = is_expert_linear(full_name)
input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
input_is_parallel, in_features, out_features, disable_tp_comm, disable_sp_comm, base_linear_is_parallel = (
get_adapter_attributes_from_linear(module, is_expert=is_expert)
)

@@ -164,6 +164,7 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
alpha=self.alpha,
is_expert=is_expert,
a2a_experimental=self.a2a_experimental,
disable_tensor_parallel_comm=disable_tp_comm,
disable_sequence_parallel_comm=disable_sp_comm,
base_linear_is_parallel=base_linear_is_parallel,
)
37 changes: 32 additions & 5 deletions src/megatron/bridge/peft/utils.py
@@ -66,10 +66,12 @@
TERL = (TERowParallelLinear, TERowParallelGroupedLinear)


def get_adapter_attributes_from_linear(m: nn.Module, is_expert: bool = False) -> Tuple[bool, int, int, bool, bool]:
def get_adapter_attributes_from_linear(
m: nn.Module, is_expert: bool = False
) -> Tuple[bool, int, int, bool, bool, bool]:
"""Returns attributes from the base layer.

input_is_parallel, in_features, out_features, disable_sequence_parallel_comm, base_linear_is_parallel
input_is_parallel, in_features, out_features, disable_tensor_parallel_comm, disable_sequence_parallel_comm, base_linear_is_parallel

This function analyzes a linear module and extracts key attributes needed for adapter configuration,
particularly for PEFT adapters in distributed training scenarios.
@@ -82,6 +84,7 @@ def get_adapter_attributes_from_linear(m: nn.Module, is_expert: bool = False) ->
- input_is_parallel: Whether the input is already parallelized
- in_features: Input feature dimension
- out_features: Output feature dimension
- disable_tensor_parallel_comm: Whether to disable tensor parallel communication
- disable_sequence_parallel_comm: Whether to disable sequence parallel communication
- base_linear_is_parallel: Whether the base linear layer uses parallelization

@@ -90,6 +93,16 @@ def get_adapter_attributes_from_linear(m: nn.Module, is_expert: bool = False) ->
"""
disable_sequence_parallel_comm = not m.config.sequence_parallel
base_linear_is_parallel = True

# In some modules (notably MoE shared_experts when moe_shared_expert_overlap is enabled),
# Megatron disables TP-related communications on the base linear layer by
# setting `parallel_mode=None` (TE) or `explicit_expert_comm=True` (legacy).
# https://github.com/NVIDIA/Megatron-LM/blob/5b1ef0703184299fbf71f6131bf2f9a5331e7238/megatron/core/transformer/moe/shared_experts.py#L95-L104
# The weights are still TP-sharded though, so we must keep using the real TP size
disable_tensor_parallel_comm = getattr(m, "parallel_mode", "") is None or getattr(m, "explicit_expert_comm", False)
if disable_tensor_parallel_comm:
disable_sequence_parallel_comm = True

if is_expert:
tp_size = parallel_state.get_expert_tensor_parallel_world_size()
else:
@@ -142,7 +155,14 @@ def get_adapter_attributes_from_linear(m: nn.Module, is_expert: bool = False) ->
else:
raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")

return input_is_parallel, in_features, out_features, disable_sequence_parallel_comm, base_linear_is_parallel
return (
input_is_parallel,
in_features,
out_features,
disable_tensor_parallel_comm,
disable_sequence_parallel_comm,
base_linear_is_parallel,
)


def is_expert_linear(fqn: str) -> bool:
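
As a reference for callers, get_adapter_attributes_from_linear now returns six values, and the two communication flags are derived from attributes Megatron sets on the base linear layer. The sketch below re-implements just that detection against mocked objects so it can run standalone; `detect_comm_flags` and the `SimpleNamespace` stand-ins are invented for the example, not part of the codebase.

```python
# Standalone illustration of the new TP/SP-comm detection; the helper name and
# the SimpleNamespace stand-ins are invented for this example.
from types import SimpleNamespace
from typing import Tuple


def detect_comm_flags(m) -> Tuple[bool, bool]:
    """Mirror the flag logic added to get_adapter_attributes_from_linear."""
    disable_sequence_parallel_comm = not m.config.sequence_parallel
    # TE layers expose `parallel_mode`; legacy layers expose `explicit_expert_comm`.
    # Either signal means Megatron already skips TP comms on this base linear.
    disable_tensor_parallel_comm = (
        getattr(m, "parallel_mode", "") is None or getattr(m, "explicit_expert_comm", False)
    )
    if disable_tensor_parallel_comm:
        disable_sequence_parallel_comm = True
    return disable_tensor_parallel_comm, disable_sequence_parallel_comm


# MoE shared-expert linear with moe_shared_expert_overlap: comms disabled,
# even though the weights themselves remain TP-sharded.
shared_expert = SimpleNamespace(
    config=SimpleNamespace(sequence_parallel=True),
    parallel_mode=None,
)
assert detect_comm_flags(shared_expert) == (True, True)

# Ordinary TE column-parallel linear: both kinds of communication stay enabled.
regular = SimpleNamespace(
    config=SimpleNamespace(sequence_parallel=True),
    parallel_mode="column",
)
assert detect_comm_flags(regular) == (False, False)

# Callers now unpack six values, for example:
#   (input_is_parallel, in_features, out_features,
#    disable_tp_comm, disable_sp_comm, base_linear_is_parallel) = (
#       get_adapter_attributes_from_linear(m, is_expert=is_expert)
#   )
```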
@@ -163,7 +183,7 @@ def is_expert_linear(fqn: str) -> bool:
>>> is_expert_linear("model.layers.0.mlp.linear_fc1")
False
"""
return re.match(r".*mlp\..*experts.*\.linear_fc[1-2]$", fqn) is not None
return re.match(r".*mlp\..*experts.*\.linear_fc[1-2]$", fqn) is not None and not ".shared_experts." in fqn


def wildcard_match(pattern: str, key: Optional[str]) -> Optional[bool]:
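
A quick check of the tightened is_expert_linear pattern (the dotted paths below are illustrative examples in the style of the docstring, not real module names): routed-expert projections still count as expert linears, while shared-expert projections are now excluded.

```python
import re


def is_expert_linear(fqn: str) -> bool:
    # Same expression as in the diff above, reproduced for a runnable demo.
    return re.match(r".*mlp\..*experts.*\.linear_fc[1-2]$", fqn) is not None and not ".shared_experts." in fqn


assert is_expert_linear("model.layers.0.mlp.experts.0.linear_fc1")           # routed expert: still matched
assert not is_expert_linear("model.layers.0.mlp.shared_experts.linear_fc1")  # shared expert: now excluded
```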
@@ -389,6 +409,7 @@ def __init__(
dropout_position: str = "post",
a2a_experimental: bool = False,
is_expert: bool = False,
disable_tensor_parallel_comm: bool = False,
disable_sequence_parallel_comm: bool = True,
base_linear_is_parallel: bool = True,
**kwargs,
@@ -410,6 +431,7 @@ def __init__(
dropout_position: When to apply dropout.
a2a_experimental: Use experimental all-to-all communication.
is_expert: Whether for expert layers in MoE.
disable_tensor_parallel_comm: Disable tensor parallel communication.
disable_sequence_parallel_comm: Disable sequence parallel communication.
dropout_recompute: Use recomputation for dropout.
**kwargs: Additional keyword arguments.
@@ -466,7 +488,12 @@ def __init__(
# if the original column parallel layer uses gather_output=False,
# then we will use the self.liner_out layer defined below.
lin_out_gather_output = True if input_is_parallel else False
if self.use_a2a and input_is_parallel and _sequence_parallel:
if (
self.use_a2a
and input_is_parallel
and _sequence_parallel
or (disable_tensor_parallel_comm and not input_is_parallel)
):
lin_out_gather_output = False

if not base_linear_is_parallel:
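
One readability note on the widened `lin_out_gather_output` condition above: `and` binds tighter than `or`, so it groups into two alternatives, the pre-existing all-to-all plus sequence-parallel case and the new disabled-TP-comm, column-parallel case. The helper below only restates that grouping with explicit parentheses; `_should_force_no_gather` is an invented name, not a function in the codebase.

```python
# Explicit-parentheses restatement of the condition in ParallelLinearAdapter.__init__;
# the helper name is invented for illustration.
def _should_force_no_gather(
    use_a2a: bool,
    input_is_parallel: bool,
    sequence_parallel: bool,
    disable_tensor_parallel_comm: bool,
) -> bool:
    return (use_a2a and input_is_parallel and sequence_parallel) or (
        disable_tensor_parallel_comm and not input_is_parallel
    )


# The new clause fires for a column-parallel base whose TP comm has been disabled:
assert _should_force_no_gather(False, False, True, True) is True
# It does not fire for an ordinary row-parallel base without all-to-all:
assert _should_force_no_gather(False, True, True, False) is False
```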
2 changes: 1 addition & 1 deletion tests/unit_tests/models/test_model_bridge_lora.py
@@ -257,7 +257,7 @@ def named_parameters(self):
)
monkeypatch.setattr(
"megatron.bridge.models.conversion.model_bridge.get_adapter_attributes_from_linear",
lambda *_args, **_kwargs: (True, None, None, None, False),
lambda *_args, **_kwargs: (True, None, None, None, None, False),
)
monkeypatch.setattr(
"torch.distributed.all_gather_object",
12 changes: 6 additions & 6 deletions tests/unit_tests/peft/test_canonical_lora.py
@@ -221,10 +221,10 @@ def test_canonical_lora_transform_fused_layers(self):
def mock_get_attrs(module, is_expert=False):
if hasattr(module, "out_features"):
if module.out_features == 1536: # linear_qkv
return (False, 512, 1536, False, True)
return (False, 512, 1536, False, True, True)
elif module.out_features == 2048: # linear_fc1
return (False, 512, 2048, False, True)
return (False, 512, 512, False, True) # default
return (False, 512, 2048, False, True, True)
return (False, 512, 512, False, True, True) # default

with patch(
"megatron.bridge.peft.canonical_lora.get_adapter_attributes_from_linear", side_effect=mock_get_attrs
@@ -487,7 +487,7 @@ def test_canonical_lora_transform_idempotent_fused_layers(self):

# Mock the get_adapter_attributes_from_linear function
with patch("megatron.bridge.peft.canonical_lora.get_adapter_attributes_from_linear") as mock_get_attrs:
mock_get_attrs.return_value = (False, 512, 1536, False, True)
mock_get_attrs.return_value = (False, 512, 1536, False, True, True)

# Mock ParallelLinearAdapter
with patch("megatron.bridge.peft.canonical_lora.ParallelLinearAdapter") as mock_adapter:
@@ -525,7 +525,7 @@ def test_megatron_style_qkv_transform(self):

# Mock the get_adapter_attributes_from_linear function
with patch("megatron.bridge.peft.canonical_lora.get_adapter_attributes_from_linear") as mock_get_attrs:
mock_get_attrs.return_value = (False, 512, 1536, False, True)
mock_get_attrs.return_value = (False, 512, 1536, False, True, True)

# Mock ParallelLinearAdapter
with patch("megatron.bridge.peft.canonical_lora.ParallelLinearAdapter") as mock_adapter:
@@ -547,7 +547,7 @@ def test_megatron_style_fc1_transform(self):

# Mock the get_adapter_attributes_from_linear function
with patch("megatron.bridge.peft.canonical_lora.get_adapter_attributes_from_linear") as mock_get_attrs:
mock_get_attrs.return_value = (False, 512, 2048, False, True)
mock_get_attrs.return_value = (False, 512, 2048, False, True, True)

# Mock ParallelLinearAdapter
with patch("megatron.bridge.peft.canonical_lora.ParallelLinearAdapter") as mock_adapter:
37 changes: 30 additions & 7 deletions tests/unit_tests/peft/test_dora.py
@@ -141,9 +141,10 @@ def test_transform_matched_module(
False,
512,
256,
False,
True,
True,
) # input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel
) # input_is_parallel, in_features, out_features, disable_tp_comm, disable_sp_comm, base_linear_is_parallel

# Create test module with config
test_module = nn.Linear(512, 256)
@@ -211,7 +212,7 @@ def mock_gather_func(tensor):

mock_gather.side_effect = mock_gather_func

mock_get_attributes.return_value = (True, 256, 128, False, True)
mock_get_attributes.return_value = (True, 256, 128, False, False, True)

test_module = nn.Linear(256, 128)
test_module.config = MockModelParallelConfig()
@@ -258,7 +259,7 @@ def test_transform_with_simple_model(self, mock_get_attributes, mock_row_linear,

# Mock get_adapter_attributes_from_linear to return appropriate values
def mock_get_attributes_func(module):
return (False, module.in_features, module.out_features, False, True)
return (
False,
module.in_features,
module.out_features,
False,
not module.config.sequence_parallel,
True,
)

mock_get_attributes.side_effect = mock_get_attributes_func

@@ -306,7 +314,7 @@ def test_full_model_application(self, mock_get_attributes, mock_row_linear, mock

# Mock get_adapter_attributes_from_linear to return appropriate values
def mock_get_attributes_func(module):
return (False, module.in_features, module.out_features, False, True)
return (
False,
module.in_features,
module.out_features,
False,
not module.config.sequence_parallel,
True,
)

mock_get_attributes.side_effect = mock_get_attributes_func

@@ -352,7 +367,8 @@ def test_wildcard_matching(self, mock_get_attributes, mock_row_linear, mock_col_

# Should match with wildcard
with patch(
"megatron.bridge.peft.dora.get_adapter_attributes_from_linear", return_value=(False, 10, 10, False, True)
"megatron.bridge.peft.dora.get_adapter_attributes_from_linear",
return_value=(False, 10, 10, False, False, True),
):
with patch("megatron.bridge.peft.dora_layers.DoRALinear._get_weight_norm", return_value=torch.randn(10)):
result = dora.transform(test_module, name="linear_qkv", prefix="layer.0.attention")
@@ -376,7 +392,7 @@ def test_dora_transform_idempotent(self):

# Mock all the necessary functions for DoRA transform
with patch("megatron.bridge.peft.dora.get_adapter_attributes_from_linear") as mock_get_attrs:
mock_get_attrs.return_value = (False, 512, 256, False, True)
mock_get_attrs.return_value = (False, 512, 256, False, True, True)

with patch("megatron.bridge.peft.utils.ColumnParallelLinear") as mock_col_linear:
# Create mocks for the adapters that will be created
@@ -416,7 +432,14 @@ def test_dora_transform_idempotent_full_model(self):
with patch("megatron.bridge.peft.dora.get_adapter_attributes_from_linear") as mock_get_attrs:

def mock_get_attributes_func(module):
return (False, module.in_features, module.out_features, False, True)
return (
False,
module.in_features,
module.out_features,
False,
not module.config.sequence_parallel,
True,
)

mock_get_attrs.side_effect = mock_get_attributes_func

40 changes: 29 additions & 11 deletions tests/unit_tests/peft/test_utils.py
@@ -276,13 +276,19 @@ def test_get_adapter_attributes_column_parallel(self, mock_parallel_state):
mock_parallel_state.get_tensor_model_parallel_world_size.return_value = 1
linear = MockColumnParallelLinear(input_size=100, output_size=50)

input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
get_adapter_attributes_from_linear(linear)
)
(
input_is_parallel,
in_features,
out_features,
disable_tp_comm,
disable_sp_comm,
base_linear_is_parallel,
) = get_adapter_attributes_from_linear(linear)

assert not input_is_parallel
assert in_features == 100
assert out_features == 50
assert not disable_tp_comm
assert disable_sp_comm # Should be True when sequence_parallel is False
assert base_linear_is_parallel # Should be True for parallel linear layers

@@ -292,13 +298,19 @@ def test_get_adapter_attributes_row_parallel(self, mock_parallel_state):
mock_parallel_state.get_tensor_model_parallel_world_size.return_value = 1
linear = MockRowParallelLinear(input_size=100, output_size=50)

input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
get_adapter_attributes_from_linear(linear)
)
(
input_is_parallel,
in_features,
out_features,
disable_tp_comm,
disable_sp_comm,
base_linear_is_parallel,
) = get_adapter_attributes_from_linear(linear)

assert input_is_parallel
assert in_features == 100
assert out_features == 50
assert not disable_tp_comm
assert disable_sp_comm
assert base_linear_is_parallel # Should be True for parallel linear layers

@@ -309,10 +321,16 @@ def test_get_adapter_attributes_sequence_parallel(self, mock_parallel_state):
linear = MockColumnParallelLinear(input_size=100, output_size=50)
linear.config.sequence_parallel = True

input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
get_adapter_attributes_from_linear(linear)
)
(
input_is_parallel,
in_features,
out_features,
disable_tp_comm,
disable_sp_comm,
base_linear_is_parallel,
) = get_adapter_attributes_from_linear(linear)

assert not disable_tp_comm
assert not disable_sp_comm # Should be False when sequence_parallel is True
assert base_linear_is_parallel # Should be True for parallel linear layers

@@ -332,12 +350,12 @@ def test_get_adapter_attributes_base_linear_is_parallel_flag(self, mock_parallel
mock_parallel_state.get_tensor_model_parallel_world_size.return_value = 1
# Test with ColumnParallelLinear - should return True for base_linear_is_parallel
column_linear = MockColumnParallelLinear(input_size=100, output_size=50)
_, _, _, _, base_linear_is_parallel = get_adapter_attributes_from_linear(column_linear)
_, _, _, _, _, base_linear_is_parallel = get_adapter_attributes_from_linear(column_linear)
assert base_linear_is_parallel # Should be True for parallel linear layers

# Test with RowParallelLinear - should return True for base_linear_is_parallel
row_linear = MockRowParallelLinear(input_size=100, output_size=50)
_, _, _, _, base_linear_is_parallel = get_adapter_attributes_from_linear(row_linear)
_, _, _, _, _, base_linear_is_parallel = get_adapter_attributes_from_linear(row_linear)
assert base_linear_is_parallel # Should be True for parallel linear layers

