Skip to content

Commit 658a8de

Browse files
committed
Use torch.is_grad_enabled instead of self.training for gradient checkpoint check and remove references to deprecated scale/lora_scale
1 parent a1606b0 commit 658a8de

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

src/diffusers/models/unets/unet_2d_blocks.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ def __init__(
736736
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
737737
hidden_states = self.resnets[0](hidden_states, temb)
738738
for attn, resnet in zip(self.attentions, self.resnets[1:]):
739-
if self.training and self.gradient_checkpointing:
739+
if torch.is_grad_enabled() and self.gradient_checkpointing:
740740

741741
def create_custom_forward(module, return_dict=None):
742742
def custom_forward(*inputs):
@@ -1155,7 +1155,7 @@ def forward(
11551155
output_states = ()
11561156

11571157
for resnet, attn in zip(self.resnets, self.attentions):
1158-
if self.training and self.gradient_checkpointing:
1158+
if torch.is_grad_enabled() and self.gradient_checkpointing:
11591159

11601160
def create_custom_forward(module, return_dict=None):
11611161
def custom_forward(*inputs):
@@ -1167,7 +1167,6 @@ def custom_forward(*inputs):
11671167
return custom_forward
11681168

11691169
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1170-
cross_attention_kwargs.update({"scale": lora_scale})
11711170
hidden_states = torch.utils.checkpoint.checkpoint(
11721171
create_custom_forward(resnet),
11731172
hidden_states,
@@ -1177,8 +1176,7 @@ def custom_forward(*inputs):
11771176
hidden_states = attn(hidden_states, **cross_attention_kwargs)
11781177
output_states = output_states + (hidden_states,)
11791178
else:
1180-
cross_attention_kwargs.update({"scale": lora_scale})
1181-
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1179+
hidden_states = resnet(hidden_states, temb)
11821180
hidden_states = attn(hidden_states, **cross_attention_kwargs)
11831181
output_states = output_states + (hidden_states,)
11841182

@@ -2402,8 +2400,8 @@ def __init__(
24022400
else:
24032401
self.upsamplers = None
24042402

2405-
self.resolution_idx = resolution_idx
24062403
self.gradient_checkpointing = False
2404+
self.resolution_idx = resolution_idx
24072405

24082406
def forward(
24092407
self,
@@ -2423,9 +2421,8 @@ def forward(
24232421
res_hidden_states = res_hidden_states_tuple[-1]
24242422
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
24252423
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2426-
cross_attention_kwargs = {"scale": scale}
24272424

2428-
if self.training and self.gradient_checkpointing:
2425+
if torch.is_grad_enabled() and self.gradient_checkpointing:
24292426

24302427
def create_custom_forward(module, return_dict=None):
24312428
def custom_forward(*inputs):
@@ -2443,10 +2440,10 @@ def custom_forward(*inputs):
24432440
temb,
24442441
**ckpt_kwargs,
24452442
)
2446-
hidden_states = attn(hidden_states, **cross_attention_kwargs)
2443+
hidden_states = attn(hidden_states)
24472444
else:
2448-
hidden_states = resnet(hidden_states, temb, scale=scale)
2449-
hidden_states = attn(hidden_states, **cross_attention_kwargs)
2445+
hidden_states = resnet(hidden_states, temb)
2446+
hidden_states = attn(hidden_states)
24502447

24512448
if self.upsamplers is not None:
24522449
for upsampler in self.upsamplers:

src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2228,7 +2228,7 @@ def __init__(
22282228
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
22292229
hidden_states = self.resnets[0](hidden_states, temb)
22302230
for attn, resnet in zip(self.attentions, self.resnets[1:]):
2231-
if self.training and self.gradient_checkpointing:
2231+
if torch.is_grad_enabled() and self.gradient_checkpointing:
22322232

22332233
def create_custom_forward(module, return_dict=None):
22342234
def custom_forward(*inputs):

0 commit comments

Comments
 (0)