Skip to content

Commit a1a9bfb

Browse files
authored
[bugfix] fix bug in diffusers/pipelines/__init__ (mindspore-lab#982)
* bugfix * bugfix
1 parent 6425119 commit a1a9bfb

File tree

6 files changed

+23
-21
lines changed

6 files changed

+23
-21
lines changed

mindone/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@
525525
StableUnCLIPPipeline,
526526
StableVideoDiffusionPipeline,
527527
TextToVideoSDPipeline,
528+
TextToVideoZeroPipeline,
528529
TextToVideoZeroSDXLPipeline,
529530
UnCLIPImageVariationPipeline,
530531
UnCLIPPipeline,
@@ -535,7 +536,6 @@
535536
WuerstchenCombinedPipeline,
536537
WuerstchenDecoderPipeline,
537538
WuerstchenPriorPipeline,
538-
TextToVideoZeroPipeline,
539539
)
540540
from .schedulers import (
541541
AmusedScheduler,

mindone/diffusers/pipelines/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,10 @@
199199
"TextToVideoSDPipeline",
200200
"TextToVideoZeroSDXLPipeline",
201201
"VideoToVideoSDPipeline",
202+
"TextToVideoZeroPipeline",
202203
],
203204
"unclip": ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"],
204-
["unidiffuser"]: [
205+
"unidiffuser": [
205206
"ImageTextPipelineOutput",
206207
"UniDiffuserModel",
207208
"UniDiffuserPipeline",
@@ -212,7 +213,6 @@
212213
"WuerstchenDecoderPipeline",
213214
"WuerstchenPriorPipeline",
214215
],
215-
"text_to_video_synthesis" : ["TextToVideoZeroPipeline"],
216216
"pipeline_utils": [
217217
"AudioPipelineOutput",
218218
"DiffusionPipeline",
@@ -383,11 +383,15 @@
383383
)
384384
from .stable_video_diffusion import StableVideoDiffusionPipeline
385385
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
386-
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroSDXLPipeline, VideoToVideoSDPipeline
386+
from .text_to_video_synthesis import (
387+
TextToVideoSDPipeline,
388+
TextToVideoZeroPipeline,
389+
TextToVideoZeroSDXLPipeline,
390+
VideoToVideoSDPipeline,
391+
)
387392
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
388393
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
389394
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline
390-
from .text_to_video_synthesis import TextToVideoZeroPipeline
391395
else:
392396
import sys
393397

mindone/diffusers/pipelines/text_to_video_synthesis/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from ...utils import _LazyModule
44

5-
65
_import_structure = {}
76

87
_import_structure["pipeline_output"] = ["TextToVideoSDPipelineOutput"]

mindone/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33

44
import numpy as np
55
import PIL
6+
67
import mindspore as ms
78

8-
from ...utils import (
9-
BaseOutput,
10-
)
9+
from ...utils import BaseOutput
1110

1211

1312
@dataclass

mindone/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import copy
22
import inspect
33
from dataclasses import dataclass
4-
from typing import Callable, List, Optional, Union, Tuple
4+
from typing import Callable, List, Optional, Union
55

66
import numpy as np
77
import PIL.Image
8+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
9+
810
import mindspore as ms
911
from mindspore import mint, ops
10-
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
1112

1213
from ...image_processor import VaeImageProcessor
1314
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
@@ -18,7 +19,6 @@
1819
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
1920
from ..stable_diffusion import StableDiffusionSafetyChecker
2021

21-
2222
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2323

2424

@@ -226,7 +226,8 @@ def warp_single_latent(latent, reference_flow):
226226
# mint.nn.functional.grid_sample not support dtype float16.
227227
if latent.dtype == ms.float16:
228228
warped = mint.nn.functional.grid_sample(
229-
latent.to(ms.float32), coords_t0.to(ms.float32), mode="nearest", padding_mode="reflection").to(ms.float16)
229+
latent.to(ms.float32), coords_t0.to(ms.float32), mode="nearest", padding_mode="reflection"
230+
).to(ms.float16)
230231
else:
231232
warped = mint.nn.functional.grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection")
232233
return warped
@@ -466,7 +467,8 @@ def check_inputs(
466467
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
467468
):
468469
raise ValueError(
469-
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
470+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found \
471+
{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
470472
)
471473

472474
if prompt is not None and prompt_embeds is not None:
@@ -712,7 +714,7 @@ def __call__(
712714

713715
self.scheduler = scheduler_copy
714716
x_1k_0 = self.backward_loop(
715-
timesteps=timesteps[-t1 - 1:],
717+
timesteps=timesteps[-t1 - 1 :],
716718
prompt_embeds=prompt_embeds,
717719
latents=x_1k_t1,
718720
guidance_scale=guidance_scale,
@@ -846,9 +848,7 @@ def encode_prompt(
846848
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not mint.equal(
847849
text_input_ids, untruncated_ids
848850
):
849-
removed_text = self.tokenizer.batch_decode(
850-
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
851-
)
851+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
852852
logger.warning(
853853
"The following part of your input was truncated because CLIP can only handle sequences up to"
854854
f" {self.tokenizer.model_max_length} tokens: {removed_text}"

tests/diffusers_tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
@ddt
2323
@slow
2424
class StableTextToVideoZeroPipelineIntegrationTests(PipelineTesterMixin, unittest.TestCase):
25-
2625
@data(*test_cases)
2726
@unpack
2827
def test_text_to_video_zero(self, mode, dtype):
2928
ms.set_context(mode=mode)
3029
ms_dtype = getattr(ms, dtype)
3130

32-
pipe = TextToVideoZeroPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5",
33-
mindspore_dtype=ms_dtype)
31+
pipe = TextToVideoZeroPipeline.from_pretrained(
32+
"stable-diffusion-v1-5/stable-diffusion-v1-5", mindspore_dtype=ms_dtype
33+
)
3434
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
3535

3636
prompt = "A bear is playing a guitar on Times Square"

0 commit comments

Comments
 (0)