Skip to content

Commit 41751a3

Browse files
committed
update
1 parent 292469d commit 41751a3

File tree

2 files changed

+21
-11
lines changed

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def _get_t5_prompt_embeds(
214214

215215
text_inputs = self.tokenizer(
216216
prompt,
217-
padding=True,
217+
padding="max_length",
218218
max_length=max_sequence_length,
219219
truncation=True,
220220
return_length=False,
@@ -286,6 +286,11 @@ def encode_prompt(
286286

287287
prompt = [prompt] if isinstance(prompt, str) else prompt
288288

289+
if prompt is not None:
290+
batch_size = len(prompt)
291+
else:
292+
batch_size = prompt_embeds.shape[0]
293+
289294
if prompt_embeds is None:
290295
prompt_embeds = self._get_t5_prompt_embeds(
291296
prompt=prompt,
@@ -295,6 +300,8 @@ def encode_prompt(
295300
)
296301

297302
if negative_prompt_embeds is None:
303+
negative_prompt = negative_prompt or ""
304+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
298305
negative_prompt_embeds = self._get_t5_prompt_embeds(
299306
prompt=negative_prompt,
300307
num_images_per_prompt=num_images_per_prompt,

tests/pipelines/chroma/test_pipeline_chroma.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,15 @@
44
import torch
55
from transformers import AutoTokenizer, T5EncoderModel
66

7-
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
8-
from diffusers.utils.testing_utils import torch_device
7+
from diffusers import (
8+
AutoencoderKL,
9+
ChromaPipeline,
10+
ChromaTransformer2DModel,
11+
FlowMatchEulerDiscreteScheduler,
12+
)
13+
from diffusers.utils.testing_utils import (
14+
torch_device,
15+
)
916

1017
from ..test_pipelines_common import (
1118
FluxIPAdapterTesterMixin,
@@ -22,9 +29,6 @@ class ChromaPipelineFastTests(
2229
):
2330
pipeline_class = ChromaPipeline
2431
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
25-
26-
pipeline_class = ChromaPipeline
27-
params = frozenset(["prompt", "negative_prompt", "height", "width", "guidance_scale", "prompt_embeds"])
2832
batch_params = frozenset(["prompt"])
2933

3034
# there is no xformers processor for Flux
@@ -39,14 +43,13 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
3943
in_channels=4,
4044
num_layers=num_layers,
4145
num_single_layers=num_single_layers,
42-
attention_head_dim=4,
43-
num_attention_heads=4,
46+
attention_head_dim=16,
47+
num_attention_heads=2,
4448
joint_attention_dim=32,
4549
axes_dims_rope=[4, 4, 8],
46-
approximator_in_factor=1,
4750
approximator_hidden_dim=32,
48-
approximator_out_dim=64,
49-
approximator_layers=5,
51+
approximator_layers=1,
52+
approximator_num_channels=16,
5053
)
5154

5255
torch.manual_seed(0)

0 commit comments

Comments (0)