Skip to content

Commit 88aa7f6

Browse files
jinghuan-Chenpatrickvonplatenyiyixuxu
authored
Make LoRACompatibleConv padding_mode work. (#6031)
* Make LoRACompatibleConv padding_mode work. * Format code style. * add fast test * Update src/diffusers/models/lora.py Simplify the code by patrickvonplaten. Co-authored-by: Patrick von Platen <[email protected]> * code refactor * apply patrickvonplaten suggestion to simplify the code. * rm test_lora_layers_old_backend.py and add test case in test_lora_layers_peft.py * update test case. --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent ad310af commit 88aa7f6

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

src/diffusers/models/lora.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,19 @@ def _unfuse_lora(self):
361361
self.w_down = None
362362

363363
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
364+
if self.padding_mode != "zeros":
365+
hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
366+
padding = (0, 0)
367+
else:
368+
padding = self.padding
369+
370+
original_outputs = F.conv2d(
371+
hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups
372+
)
373+
364374
if self.lora_layer is None:
365-
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
366-
# see: https://github.com/huggingface/diffusers/pull/4315
367-
return F.conv2d(
368-
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
369-
)
375+
return original_outputs
370376
else:
371-
original_outputs = F.conv2d(
372-
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
373-
)
374377
return original_outputs + (scale * self.lora_layer(hidden_states))
375378

376379

tests/lora/test_lora_layers_peft.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,24 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
11771177
# Just makes sure it works..
11781178
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
11791179

1180+
def test_modify_padding_mode(self):
1181+
def set_pad_mode(network, mode="circular"):
1182+
for _, module in network.named_modules():
1183+
if isinstance(module, torch.nn.Conv2d):
1184+
module.padding_mode = mode
1185+
1186+
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
1187+
components, _, _ = self.get_dummy_components(scheduler_cls)
1188+
pipe = self.pipeline_class(**components)
1189+
pipe = pipe.to(self.torch_device)
1190+
pipe.set_progress_bar_config(disable=None)
1191+
_pad_mode = "circular"
1192+
set_pad_mode(pipe.vae, _pad_mode)
1193+
set_pad_mode(pipe.unet, _pad_mode)
1194+
1195+
_, _, inputs = self.get_dummy_inputs()
1196+
_ = pipe(**inputs).images
1197+
11801198

11811199
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
11821200
pipeline_class = StableDiffusionPipeline

0 commit comments

Comments
 (0)