
Commit 8dd065b

hlky and a-r-r-o-w authored
Apply suggestions from code review
Co-authored-by: Aryan <[email protected]>
1 parent b6b9b45 commit 8dd065b

File tree

1 file changed (+6, -30 lines)


src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 6 additions & 30 deletions
@@ -892,26 +892,14 @@ def forward(
             cur_encoder_hidden_states = torch.cat(
                 [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
             )
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     cur_encoder_hidden_states,
                     adaln_input,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, initial_encoder_hidden_states = block(
@@ -938,26 +926,14 @@ def custom_forward(*inputs):
         for bid, block in enumerate(self.single_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     None,
                     adaln_input,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
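For context, here is a minimal standalone sketch of the pattern the commit moves to. The call-site names (`gradient_checkpointing`, `_gradient_checkpointing_func`, the `torch.is_grad_enabled()` guard) come from the diff; the way the helper is wired up in `enable_gradient_checkpointing` is an assumption about a plausible setup, not the verbatim diffusers implementation.

```python
import functools

import torch
import torch.nn as nn


class TinyBlockStack(nn.Module):
    """Toy stand-in for a transformer with per-block gradient checkpointing."""

    def __init__(self, dim: int = 64, num_blocks: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        # Assumption: the helper is plain non-reentrant torch.utils.checkpoint,
        # bound once so each call site stays a single line (as in the new diff).
        self._gradient_checkpointing_func = functools.partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )
        self.gradient_checkpointing = True

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Mirrors the new condition: keyed on torch.is_grad_enabled() rather
            # than self.training, so checkpointing also engages in eval mode when
            # autograd is recording, and is skipped under torch.no_grad().
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


if __name__ == "__main__":
    model = TinyBlockStack()
    model.enable_gradient_checkpointing()
    x = torch.randn(2, 64, requires_grad=True)
    model(x).sum().backward()  # block activations are recomputed in the backward pass
```

Relative to the removed code, the per-call `create_custom_forward` closure and the `is_torch_version`-guarded `ckpt_kwargs` disappear from the call sites, and switching the guard from `self.training` to `torch.is_grad_enabled()` ties checkpointing to whether autograd is active rather than to train/eval mode.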

0 commit comments
