
Commit a8afedf

Latest SDXL commits
1 parent 6b58622 commit a8afedf


3 files changed, +14 -12 lines changed


examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 1 addition & 1 deletion
@@ -829,7 +829,7 @@ def collate_fn(examples):
     print(f" Total optimization steps = {args.max_train_steps}")
 
     # unet = add_checkpoints(unet)
-
+    # import pdb; pdb.set_trace()
     trainer = TrainSD(
         weight_dtype=weight_dtype,
         device=device,

src/diffusers/models/attention_processor.py

Lines changed: 11 additions & 10 deletions
@@ -3510,7 +3510,7 @@ def __call__(
         # if attn.group_norm is not None:
         #     hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-
+        # batch_size = hidden_states.shape[0]
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
@@ -3528,10 +3528,10 @@ def __call__(
 
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+        """
         assert attn.norm_q is None
         assert attn.norm_k is None
-        """
+
         # if attn.norm_q is not None:
         #     query = attn.norm_q(query)
         # if attn.norm_k is not None:
@@ -3557,26 +3557,27 @@ def __call__(
         # #     logger.warning(
         # #         "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
         # #     )
-        #     hidden_states = self.scaled_dot_product_attention_compiled(
-        #         query, key, value
-        #     )
+        #     hidden_states = self.scaled_dot_product_attention(
+        #         query, key, value
+        #     )
 
         #*hidden_states = JaxFun.apply(query, key, value)
+        import pdb; pdb.set_trace()
         hidden_states = JaxFun.apply(hidden_states, encoder_hidden_states, attn.to_q.weight, attn.to_k.weight, attn.to_v.weight, attn.heads)
         hidden_states = hidden_states.to(input_dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        # hidden_states = attn.to_out[1](hidden_states)
 
         # if input_ndim == 4:
         #     hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
+        # if attn.residual_connection:
+        #     hidden_states = hidden_states + residual
 
-        hidden_states = hidden_states / attn.rescale_output_factor
+        # hidden_states = hidden_states / attn.rescale_output_factor
 
         return hidden_states
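Note on the JaxFun.apply call above: this diff routes the q/k/v projection and the attention itself through JaxFun, which, given the pallas-kernel comments in the surrounding code, presumably bridges to a JAX/Pallas attention kernel via a custom autograd Function. The sketch below is a hypothetical plain-PyTorch stand-in that only illustrates the call contract (hidden states, raw projection weights, head count in; merged-head output out), not the repository's actual implementation; the helper name is invented for illustration.

import torch
import torch.nn.functional as F

def attention_from_weights(hidden_states, encoder_hidden_states, q_weight, k_weight, v_weight, heads):
    # Stand-in for JaxFun.apply: same inputs/outputs, plain PyTorch ops.
    batch_size = hidden_states.shape[0]
    # Project with the raw nn.Linear weights (no bias), since attn.to_q.weight etc. are passed in.
    query = hidden_states @ q_weight.t()
    key = encoder_hidden_states @ k_weight.t()
    value = encoder_hidden_states @ v_weight.t()
    head_dim = query.shape[-1] // heads
    # Split heads: (batch, seq, inner_dim) -> (batch, heads, seq, head_dim).
    query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(query, key, value)
    # Merge heads back to (batch, seq, inner_dim); the caller then applies attn.to_out[0].
    return out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)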
src/diffusers/models/transformers/transformer_2d.py

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,7 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import LegacyModelMixin
 from ..normalization import AdaLayerNormSingle
-
+import torch_xla.debug.profiler as xp
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -321,6 +321,7 @@ def _init_patched_inputs(self, norm_type):
             in_features=self.caption_channels, hidden_size=self.inner_dim
         )
 
+    @xp.trace_me("Transformer2Dmodel")
     def forward(
         self,
         hidden_states: torch.Tensor,

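The @xp.trace_me("Transformer2Dmodel") decorator only attaches a named span to the transformer's forward pass; the span is visible once a torch_xla profiler trace is captured. Below is a minimal sketch of the usual capture workflow, assuming torch_xla's standard profiler API; the port, log directory, and duration are placeholders, not values from this commit.

import torch_xla.debug.profiler as xp

# In the training script: start the profiler server once, early in the process.
server = xp.start_server(9012)

# Ad-hoc spans can also be added around arbitrary code with the context manager;
# they show up in the trace alongside @xp.trace_me-decorated calls.
with xp.Trace("denoising_step"):
    pass  # region to profile

# From a separate shell/process, capture a trace while training is running, e.g.:
#   xp.trace("localhost:9012", logdir="/tmp/xla_profile", duration_ms=10000)
# then inspect the result in TensorBoard's profile plugin.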