Skip to content

Commit b402629

Browse files
authored
Merge branch 'main' into xpu-enabling
2 parents f3a519f + b572635 commit b402629

File tree

3 files changed

+43 -7 lines changed

3 files changed

+43 -7 lines changed

examples/community/rerender_a_video.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,10 +30,17 @@
3030
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
3131
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
3232
from diffusers.schedulers import KarrasDiffusionSchedulers
33-
from diffusers.utils import BaseOutput, deprecate, logging
33+
from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging
3434
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
3535

3636

37+
if is_torch_xla_available():
38+
import torch_xla.core.xla_model as xm
39+
40+
XLA_AVAILABLE = True
41+
else:
42+
XLA_AVAILABLE = False
43+
3744
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3845

3946

@@ -775,7 +782,7 @@ def __call__(
775782
self.attn_state.reset()
776783

777784
# 4.1 prepare frames
778-
image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32)
785+
image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype)
779786
first_image = image[0] # C, H, W
780787

781788
# 4.2 Prepare controlnet_conditioning_image
@@ -919,8 +926,8 @@ def __call__(
919926
prev_image = frames[idx - 1]
920927
control_image = control_frames[idx]
921928
# 5.1 prepare frames
922-
image = self.image_processor.preprocess(image).to(dtype=torch.float32)
923-
prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
929+
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
930+
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
924931

925932
warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
926933
self.flow_model, first_image, image[0], first_result, False, self.device
@@ -1100,6 +1107,9 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):
11001107
if callback is not None and i % callback_steps == 0:
11011108
callback(i, t, latents)
11021109

1110+
if XLA_AVAILABLE:
1111+
xm.mark_step()
1112+
11031113
return latents
11041114

11051115
if mask_start_t <= mask_end_t:

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,30 @@
1111
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
1212

1313
from ...image_processor import VaeImageProcessor
14-
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
14+
from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
1515
from ...models import AutoencoderKL, UNet2DConditionModel
1616
from ...models.lora import adjust_lora_scale_text_encoder
1717
from ...schedulers import KarrasDiffusionSchedulers
18-
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
18+
from ...utils import (
19+
USE_PEFT_BACKEND,
20+
BaseOutput,
21+
is_torch_xla_available,
22+
logging,
23+
scale_lora_layers,
24+
unscale_lora_layers,
25+
)
1926
from ...utils.torch_utils import randn_tensor
2027
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
2128
from ..stable_diffusion import StableDiffusionSafetyChecker
2229

2330

31+
if is_torch_xla_available():
32+
import torch_xla.core.xla_model as xm
33+
34+
XLA_AVAILABLE = True
35+
else:
36+
XLA_AVAILABLE = False
37+
2438
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2539

2640

@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
282296

283297

284298
class TextToVideoZeroPipeline(
285-
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
299+
DiffusionPipeline,
300+
StableDiffusionMixin,
301+
TextualInversionLoaderMixin,
302+
StableDiffusionLoraLoaderMixin,
303+
FromSingleFileMixin,
286304
):
287305
r"""
288306
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -440,6 +458,10 @@ def backward_loop(
440458
if callback is not None and i % callback_steps == 0:
441459
step_idx = i // getattr(self.scheduler, "order", 1)
442460
callback(step_idx, t, latents)
461+
462+
if XLA_AVAILABLE:
463+
xm.mark_step()
464+
443465
return latents.clone().detach()
444466

445467
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs

tests/lora/test_lora_layers_sd3.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,11 @@
2929
from diffusers.utils import load_image
3030
from diffusers.utils.import_utils import is_accelerate_available
3131
from diffusers.utils.testing_utils import (
32+
nightly,
3233
numpy_cosine_similarity_distance,
3334
require_peft_backend,
3435
require_torch_gpu,
36+
slow,
3537
torch_device,
3638
)
3739

@@ -126,6 +128,8 @@ def test_modify_padding_mode(self):
126128
pass
127129

128130

131+
@slow
132+
@nightly
129133
@require_torch_gpu
130134
@require_peft_backend
131135
class LoraSD3IntegrationTests(unittest.TestCase):

0 commit comments

Comments
 (0)