Commit c6ff80c
Add inference code
1 parent 0b95ec0 commit c6ff80c

5 files changed: +65 -11 lines changed
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import torch
+import os
+import sys
+import numpy as np
+
+import torch_xla.core.xla_model as xm
+from time import time
+from diffusers import StableDiffusionXLPipeline
+import torch_xla.runtime as xr
+
+CACHE_DIR = os.environ.get("CACHE_DIR", '/mnt/bbahl/xla_cache/')
+if CACHE_DIR:
+    xr.initialize_cache(CACHE_DIR, readonly=False)
+
+
+device = xm.xla_device()
+model_path = "/mnt/bbahl/trained-model"
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    model_path,
+    torch_dtype=torch.bfloat16
+)
+pipe.to(device)
+prompt = ["A naruto with green eyes and red legs."]
+
+pipe.unet.enable_xla_attention()
+# pipe.vae.enable_xla_attention()
+start = time()
+print("compiling...")
+import pdb; pdb.set_trace()
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+print(f"compile time: {time() - start}")
+print("generate...")
+start = time()
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+print(f"generation time (after compile) : {time() - start}")
+image.save("naruto.png")
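The xr.initialize_cache call above enables torch_xla's persistent compilation cache, so the expensive first-run compile can be skipped on later launches of the script. A minimal standalone sketch of the same setup (the /tmp fallback path is an assumption, not from the commit):

import os
import torch_xla.runtime as xr

# Persist compiled XLA programs to disk; readonly=False lets this process
# write new cache entries. Later runs with identical graphs skip recompilation.
cache_dir = os.environ.get("CACHE_DIR", "/tmp/xla_cache")
os.makedirs(cache_dir, exist_ok=True)
xr.initialize_cache(cache_dir, readonly=False)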

src/diffusers/models/attention.py
Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-import torch_xla.debug.profiler as xp
 
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph

src/diffusers/models/attention_processor.py
Lines changed: 24 additions & 5 deletions
@@ -14,11 +14,9 @@
 import inspect
 import math
 from typing import Callable, List, Optional, Tuple, Union
-import functools
 import torch
 import torch.nn.functional as F
 from torch import nn
-import torch_xla.debug.profiler as xp
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, is_torch_xla_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
@@ -3239,6 +3237,24 @@ def __call__(
 
         return hidden_states
 
+def scaled_dot_product_attention(q, k, v):
+    """
+    Compute the attention weights and output using scaled dot-product attention.
+
+    Args:
+        q (`torch.Tensor`):
+            Query tensor of shape (batch_size, num_heads, seq_length, head_dim).
+        k (`torch.Tensor`):
+            Key tensor of shape (batch_size, num_heads, seq_length, head_dim).
+        v (`torch.Tensor`):
+            Value tensor of shape (batch_size, num_heads, seq_length, head_dim).
+
+    Returns:
+        `torch.Tensor`: The output tensor after applying attention.
+    """
+    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
+    attn_weights = F.softmax(attn_weights, dim=-1)
+    return torch.matmul(attn_weights, v)
 
 class AttnProcessor2_0:
     r"""
@@ -3268,7 +3284,6 @@ def __call__(
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
         input_ndim = hidden_states.ndim
-
         if input_ndim == 4:
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
@@ -3311,8 +3326,8 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = scaled_dot_product_attention(
+            query, key, value
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -3464,6 +3479,10 @@ def __call__(
         *args,
         **kwargs,
     ) -> torch.Tensor:
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         hidden_states = CrossAttention.apply(hidden_states, encoder_hidden_states, attn.to_q.weight, attn.to_k.weight, attn.to_v.weight, attn.heads)
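The new scaled_dot_product_attention replaces the fused F.scaled_dot_product_attention call with an explicit matmul/softmax/matmul chain that lowers cleanly through XLA; note that the swapped-in call also drops attn_mask, dropout_p, and is_causal, so this path now assumes mask-free, non-causal attention. A quick standalone sanity check (not part of the commit) that the two agree in that regime:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # Explicit (Q K^T / sqrt(d)) softmax, then weighted sum over the values.
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    attn_weights = F.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

q, k, v = (torch.randn(2, 8, 64, 32) for _ in range(3))
expected = F.scaled_dot_product_attention(q, k, v)  # no mask, non-causal, no dropout
assert torch.allclose(scaled_dot_product_attention(q, k, v), expected, atol=1e-5)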

src/diffusers/models/transformers/transformer_2d.py
Lines changed: 0 additions & 2 deletions
@@ -24,7 +24,6 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import LegacyModelMixin
 from ..normalization import AdaLayerNormSingle
-import torch_xla.debug.profiler as xp
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -321,7 +320,6 @@ def _init_patched_inputs(self, norm_type):
             in_features=self.caption_channels, hidden_size=self.inner_dim
         )
 
-    @xp.trace_me("Transformer2Dmodel")
     def forward(
         self,
         hidden_states: torch.Tensor,
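For later profiling runs, the removed annotation can be reinstated: torch_xla's xp.trace_me decorator simply names a scope in the captured profiler trace. A minimal sketch of the pattern (the simplified forward signature is illustrative, not the model's real one):

import torch
import torch_xla.debug.profiler as xp

class Block(torch.nn.Module):
    @xp.trace_me("Block.forward")  # label this call in the profiler timeline
    def forward(self, x):
        return x * 2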

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
Lines changed: 5 additions & 3 deletions
@@ -1109,7 +1109,7 @@ def __call__(
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas
+            self.scheduler, num_inference_steps, 'cpu', timesteps, sigmas
         )
 
         # 5. Prepare latent variables
@@ -1209,8 +1209,9 @@ def __call__(
 
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                prompt_embeds = prompt_embeds.to(dtype=torch.bfloat16)
                 # predict the noise residual
-                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                added_cond_kwargs = {"text_embeds": add_text_embeds.to(dtype=torch.bfloat16), "time_ids": add_time_ids.to(dtype=torch.bfloat16)}
                 if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                     added_cond_kwargs["image_embeds"] = image_embeds
                 noise_pred = self.unet(
@@ -1295,7 +1296,8 @@ def __call__(
             self.vae.to(dtype=torch.float16)
         else:
             image = latents
-
+        xm.mark_step()
+        image = image.to('cpu')
         if not output_type == "latent":
             # apply watermark if available
             if self.watermark is not None:
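The added xm.mark_step() is what actually cuts the lazily traced XLA graph: up to that point the denoising loop has only queued operations, and the barrier compiles and executes them so the following .to('cpu') returns materialized pixels rather than forcing execution mid-transfer. A standalone sketch of the pattern (illustrative, not from the commit):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = (x @ x).relu()    # lazily traced; nothing has executed on the device yet
xm.mark_step()        # barrier: compile and run the pending graph
y_host = y.to('cpu')  # copy the materialized result to the host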
