Commit ade1059

[Flux.1] improve pos embed for ascend npu by computing on npu (#12897)
* [Flux.1] improve pos embed for ascend npu by setting it back to npu computation.
* [Flux.2] improve pos embed for ascend npu by setting it back to npu computation.
* [LongCat-Image] improve pos embed for ascend npu by setting it back to npu computation.
* [Ovis-Image] improve pos embed for ascend npu by setting it back to npu computation.
* Remove unused import of is_torch_npu_available

Co-authored-by: zhangtao <[email protected]>
1 parent 41a6e86 commit ade1059

File tree

4 files changed: +9 −27 lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 2 additions & 6 deletions
@@ -22,7 +22,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward

@@ -717,11 +717,7 @@ def forward(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
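With the guard gone, self.pos_embed(ids) simply runs on whatever device ids lives on, Ascend NPU included; the same simplification is applied to the three transformers below. A minimal sketch of that call path, assuming torch_npu is installed and an NPU device is visible (the theta/axes_dim values mirror the Flux defaults):

import torch
from diffusers.models.embeddings import FluxPosEmbed

# The same rotary-embedding module the Flux transformer builds in __init__.
pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# Packed (seq_len, 3) position ids, created directly on the NPU.
ids = torch.zeros(512, 3, device="npu")

# The returned rotary tables stay on the NPU; no .cpu()/.npu() round-trip is needed.
freqs_cos, freqs_sin = pos_embed(ids)
print(freqs_cos.device, freqs_sin.device)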

src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 3 additions & 9 deletions
@@ -21,7 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn

@@ -835,14 +835,8 @@ def forward(
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]
 
-        if is_torch_npu_available():
-            freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
-            image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
-            freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
-            text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
-        else:
-            image_rotary_emb = self.pos_embed(img_ids)
-            text_rotary_emb = self.pos_embed(txt_ids)
+        image_rotary_emb = self.pos_embed(img_ids)
+        text_rotary_emb = self.pos_embed(txt_ids)
         concat_rotary_emb = (
             torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
             torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),

src/diffusers/models/transformers/transformer_longcat_image.py

Lines changed: 2 additions & 6 deletions
@@ -21,7 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn

@@ -499,11 +499,7 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:

src/diffusers/models/transformers/transformer_ovis_image.py

Lines changed: 2 additions & 6 deletions
@@ -21,7 +21,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn

@@ -530,11 +530,7 @@ def forward(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
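For anyone validating the change locally, a hypothetical parity check along these lines (the shapes, tolerances, and random ids are illustrative, not from this PR) compares the rotary tables computed on the NPU against a CPU reference:

import torch
from diffusers.models.embeddings import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
ids = torch.randint(0, 64, (256, 3)).float()

# Reference tables on CPU vs. tables computed directly on the NPU.
cos_cpu, sin_cpu = pos_embed(ids)
cos_npu, sin_npu = pos_embed(ids.to("npu"))

torch.testing.assert_close(cos_npu.cpu(), cos_cpu, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(sin_npu.cpu(), sin_cpu, rtol=1e-3, atol=1e-3)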
