Commit 5b60154

ENH LoRA ConvNd layers using the groups argument. (#2403)
Conv layers with groups > 1 are supported, but merging them is not.
1 parent 93eea9c commit 5b60154
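
For orientation, here is a minimal usage sketch (not part of this commit) of what the change enables. It mirrors the depthwise-style grouped convolution added to the test suite below; the TinyGroupedConvNet module, its dimensions, and the rank are illustrative assumptions.

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class TinyGroupedConvNet(nn.Module):
    """Hypothetical toy model with a grouped (depthwise) convolution."""

    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(5, 5, 3, groups=5)

    def forward(self, x):
        return self.conv2d(x)


# Targeting the grouped conv layer now works; creating the adapter emits the new
# warning that merging layers with groups > 1 is not supported.
model = TinyGroupedConvNet()
config = LoraConfig(r=4, target_modules=["conv2d"], init_lora_weights=False)
peft_model = get_peft_model(model, config)

x = torch.randn(2, 5, 3, 3)
print(peft_model(x).shape)  # forward through base layer plus LoRA branch: torch.Size([2, 5, 1, 1])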

File tree

2 files changed (+71, -8 lines):
  src/peft/tuners/lora/layer.py
  tests/test_custom_models.py

src/peft/tuners/lora/layer.py

Lines changed: 17 additions & 8 deletions
@@ -1028,6 +1028,9 @@ def __init__(
         super().__init__()
         LoraLayer.__init__(self, base_layer)

+        if base_layer.groups > 1:
+            warnings.warn("LoRA adapter added to ConvNd layer with groups > 1. Merging is not supported.")
+
         self._active_adapter = adapter_name
         self._kernel_dim = base_layer.weight.dim()

@@ -1064,7 +1067,9 @@ def update_layer(
         conv_layer = type(base_layer)
         out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
         self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
-        self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=lora_bias)
+        self.lora_B[adapter_name] = conv_layer(
+            r, self.out_features // base_layer.groups, out_kernel, out_stride, bias=lora_bias
+        )
         self.lora_bias[adapter_name] = lora_bias

         if use_rslora:

@@ -1129,6 +1134,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         for active_adapter in adapter_names:
             if active_adapter in self.lora_A.keys():
                 base_layer = self.get_base_layer()
+
+                if base_layer.groups > 1:
+                    # https://github.com/huggingface/peft/pull/2403
+                    raise NotImplementedError("Merging is not supported for _ConvNd layers with groups > 1!")
+
                 if safe_merge:
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.

@@ -1246,13 +1256,12 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
                 3
             ) * self.scaling[adapter]
         else:
-            output_tensor = (
-                self.conv_fn(
-                    weight_A.transpose(0, 1),
-                    weight_B,
-                ).transpose(0, 1)
-                * self.scaling[adapter]
-            )
+            output_tensor = self.conv_fn(weight_A.transpose(0, 1), weight_B)
+
+            if self.get_base_layer().groups > 1:
+                output_tensor = output_tensor * self.scaling[adapter]
+            else:
+                output_tensor = output_tensor.transpose(0, 1) * self.scaling[adapter]

         if cast_to_fp32:
             output_tensor = output_tensor.to(dtype=dtype)
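
To see why the groups > 1 branch of get_delta_weight above skips the final transpose, here is a small shape check that is not from the repository. It assumes the depthwise configuration used by the new test model, nn.Conv2d(5, 5, 3, groups=5), an illustrative rank of 4, and that conv_fn stands for F.conv2d in the 2D case.

import torch
import torch.nn as nn
import torch.nn.functional as F

# lora_A.weight has shape (r, in_channels, k, k); after the update_layer change
# above, lora_B.weight has shape (out_channels // groups, r, 1, 1).
weight_A = torch.randn(4, 5, 3, 3)
weight_B = torch.randn(1, 4, 1, 1)

delta = F.conv2d(weight_A.transpose(0, 1), weight_B)
print(delta.shape)  # torch.Size([5, 1, 3, 3])

# A grouped base weight already has shape (out_channels, in_channels // groups, k, k),
# so the delta matches it directly and no transpose(0, 1) is needed in that branch.
print(nn.Conv2d(5, 5, 3, groups=5).weight.shape)  # torch.Size([5, 1, 3, 3])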

tests/test_custom_models.py

Lines changed: 54 additions & 0 deletions
@@ -115,6 +115,8 @@
     ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
     ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
     ("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}),
+    ("Conv2d Groups LoRA", "Conv2dGroups", LoraConfig, {"target_modules": ["conv2d"]}),
+    ("Conv2d Groups LoRA with DoRA", "Conv2dGroups", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
     ("Conv3d 1 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"]}),
     ("Conv3d 2 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d", "lin0"]}),
     ("Conv3d 1 LoRA with DoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"], "use_dora": True}),

@@ -903,6 +905,25 @@ def forward(self, X):
         return X


+class ModelConv2DGroups(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv2d = nn.Conv2d(5, 5, 3, groups=5)
+        self.relu = nn.ReLU()
+        self.flat = nn.Flatten()
+        self.lin0 = nn.Linear(5, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
+
+    def forward(self, X):
+        X = X.float().reshape(-1, 5, 3, 3)
+        X = self.conv2d(X)
+        X = self.relu(X)
+        X = self.flat(X)
+        X = self.lin0(X)
+        X = self.sm(X)
+        return X
+
+
 class ModelConv3D(nn.Module):
     def __init__(self):
         super().__init__()

@@ -967,6 +988,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
         if model_id == "Conv2d":
             return ModelConv2D().to(torch_dtype)

+        if model_id == "Conv2dGroups":
+            return ModelConv2DGroups().to(torch_dtype)
+
         if model_id == "Conv3d":
             return ModelConv3D().to(torch_dtype)

@@ -1038,6 +1062,12 @@ def test_load_multiple_adapters(self, test_name, model_id, config_cls, config_kwargs):

     @parameterized.expand(TEST_CASES)
     def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
+        # https://github.com/huggingface/peft/pull/2403
+        if model_id in ["Conv2dGroups"]:
+            pytest.skip(
+                f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
+            )
+
         config_kwargs = config_kwargs.copy()
         if issubclass(config_cls, LoraConfig):
             config_kwargs["init_lora_weights"] = False

@@ -1055,6 +1085,12 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):

     @parameterized.expand(TEST_CASES)
     def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
+        # https://github.com/huggingface/peft/pull/2403
+        if model_id in ["Conv2dGroups"]:
+            pytest.skip(
+                f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
+            )
+
         config_kwargs = config_kwargs.copy()
         if issubclass(config_cls, LoraConfig):
             config_kwargs["init_lora_weights"] = False

@@ -1064,6 +1100,12 @@ def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):

     @parameterized.expand(TEST_CASES)
     def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs):
+        # https://github.com/huggingface/peft/pull/2403
+        if model_id in ["Conv2dGroups"]:
+            pytest.skip(
+                f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
+            )
+
         # calling merge twice with the same arguments should not change the output
         config_kwargs = config_kwargs.copy()
         if issubclass(config_cls, LoraConfig):

@@ -1074,6 +1116,12 @@ def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs):

     @parameterized.expand(TEST_CASES)
     def test_safe_merge(self, test_name, model_id, config_cls, config_kwargs):
+        # https://github.com/huggingface/peft/pull/2403
+        if model_id in ["Conv2dGroups"]:
+            pytest.skip(
+                f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
+            )
+
         # calling merge twice with the same arguments should not change the output
         config_kwargs = config_kwargs.copy()
         if issubclass(config_cls, LoraConfig):

@@ -1291,6 +1339,12 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):

     @parameterized.expand(TEST_CASES)
     def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs):
+        # https://github.com/huggingface/peft/pull/2403
+        if model_id in ["Conv2dGroups"]:
+            pytest.skip(
+                f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
+            )
+
         # same as test_disable_adapters, but with merging
         X = self.prepare_inputs_for_testing()
         model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
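
The merge-related tests are skipped for the grouped model because merging now raises deliberately. Below is a sketch, not part of the test suite, that reproduces the failure those skips avoid; it assumes a reduced stand-in for ModelConv2DGroups and uses the public get_peft_model and merge_and_unload entry points.

import pytest
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class GroupedConvOnly(nn.Module):
    """Hypothetical reduction of ModelConv2DGroups to just its grouped conv."""

    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(5, 5, 3, groups=5)

    def forward(self, x):
        return self.conv2d(x)


peft_model = get_peft_model(GroupedConvOnly(), LoraConfig(target_modules=["conv2d"]))

# Merging routes through the _ConvNd.merge guard added in this commit.
with pytest.raises(NotImplementedError, match="groups > 1"):
    peft_model.merge_and_unload()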
