- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.5k
NPU Adaption for Sanna #10409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
NPU Adaption for Sanna #10409
Changes from 1 commit
626e339
              1a72a00
              4d67c54
              a1965dd
              2c3b117
              326b98d
              715822f
              963e290
              510e1d6
              3d3aae3
              4cea819
              0d9e1b3
              2052049
              487dd1a
              cfbbb8f
              7b8ad74
              d7d54d8
              ad4beaa
              52d8c71
              4c1d56d
              63e3459
              d61d570
              ab2d71b
              a323229
              a456fb1
              fedfdd4
              7364276
              3add6de
              70cf529
              8f18aae
              feb8064
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -63,6 +63,7 @@ | |
| is_wandb_available, | ||
| ) | ||
| from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card | ||
| from diffusers.utils.import_utils import is_torch_npu_available | ||
| from diffusers.utils.torch_utils import is_compiled_module | ||
|  | ||
|  | ||
|  | @@ -74,6 +75,9 @@ | |
|  | ||
| logger = get_logger(__name__) | ||
|  | ||
| if is_torch_npu_available(): | ||
| torch.npu.config.allow_internal_format = False | ||
|  | ||
|  | ||
| def save_model_card( | ||
| repo_id: str, | ||
|  | @@ -920,8 +924,7 @@ def main(args): | |
| image.save(image_filename) | ||
|  | ||
| del pipeline | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.empty_cache() | ||
| free_memory() | ||
|  | ||
| # Handle the repository creation | ||
| if accelerator.is_main_process: | ||
|  | @@ -979,10 +982,10 @@ def main(args): | |
| ) | ||
|  | ||
| # VAE should always be kept in fp32 for SANA (?) | ||
| vae.to(dtype=torch.float32) | ||
| vae.to(accelerator.device, dtype=torch.float32) | ||
| transformer.to(accelerator.device, dtype=weight_dtype) | ||
| # because Gemma2 is particularly suited for bfloat16. | ||
| text_encoder.to(dtype=torch.bfloat16) | ||
| text_encoder.to(accelerator.device, dtype=torch.bfloat16) | ||
|          | ||
|  | ||
| # Initialize a text encoding pipeline and keep it to CPU for now. | ||
| text_encoding_pipeline = SanaPipeline.from_pretrained( | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -19,11 +19,12 @@ | |
|  | ||
| from ...configuration_utils import ConfigMixin, register_to_config | ||
| from ...loaders import PeftAdapterMixin | ||
| from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers | ||
| from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available | ||
| from ..attention_processor import ( | ||
| Attention, | ||
| AttentionProcessor, | ||
| AttnProcessor2_0, | ||
| AttnProcessorNPU, | ||
| SanaLinearAttnProcessor2_0, | ||
| ) | ||
| from ..embeddings import PatchEmbed, PixArtAlphaTextProjection | ||
|  | @@ -119,6 +120,12 @@ def __init__( | |
| # 2. Cross Attention | ||
| if cross_attention_dim is not None: | ||
| self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) | ||
|  | ||
|          | ||
| if is_torch_npu_available(): | ||
|          | ||
| attn_processor = AttnProcessorNPU() | ||
| else: | ||
| attn_processor = AttnProcessor2_0() | ||
|  | ||
| self.attn2 = Attention( | ||
| query_dim=dim, | ||
| cross_attention_dim=cross_attention_dim, | ||
|  | @@ -127,7 +134,7 @@ def __init__( | |
| dropout=dropout, | ||
| bias=True, | ||
| out_bias=attention_out_bias, | ||
| processor=AttnProcessor2_0(), | ||
| processor=attn_processor, | ||
| ) | ||
|  | ||
| # 3. Feed-forward | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not needed. As we conditionally put the VAE on and off the accelerator device.