@@ -67,31 +67,13 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
6767 attention_head_dim = 16 ,
6868 num_attention_heads = 2 ,
6969 joint_attention_dim = 32 ,
70- pooled_projection_dim = 32 ,
7170 axes_dims_rope = [4 , 4 , 8 ],
7271 )
73- clip_text_encoder_config = CLIPTextConfig (
74- bos_token_id = 0 ,
75- eos_token_id = 2 ,
76- hidden_size = 32 ,
77- intermediate_size = 37 ,
78- layer_norm_eps = 1e-05 ,
79- num_attention_heads = 4 ,
80- num_hidden_layers = 5 ,
81- pad_token_id = 1 ,
82- vocab_size = 1000 ,
83- hidden_act = "gelu" ,
84- projection_dim = 32 ,
85- )
86-
87- torch .manual_seed (0 )
88- text_encoder = CLIPTextModel (clip_text_encoder_config )
8972
9073 torch .manual_seed (0 )
91- text_encoder_2 = T5EncoderModel .from_pretrained ("hf-internal-testing/tiny-random-t5" )
74+ text_encoder = T5EncoderModel .from_pretrained ("hf-internal-testing/tiny-random-t5" )
9275
93- tokenizer = CLIPTokenizer .from_pretrained ("hf-internal-testing/tiny-random-clip" )
94- tokenizer_2 = AutoTokenizer .from_pretrained ("hf-internal-testing/tiny-random-t5" )
76+ tokenizer = AutoTokenizer .from_pretrained ("hf-internal-testing/tiny-random-t5" )
9577
9678 torch .manual_seed (0 )
9779 vae = AutoencoderKL (
@@ -113,7 +95,6 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
11395 return {
11496 "scheduler" : scheduler ,
11597 "text_encoder" : text_encoder ,
116- "text_encoder_2" : text_encoder_2 ,
11798 "tokenizer" : tokenizer ,
11899 "tokenizer_2" : tokenizer_2 ,
119100 "transformer" : transformer ,
@@ -130,6 +111,7 @@ def get_dummy_inputs(self, device, seed=0):
130111
131112 inputs = {
132113 "prompt" : "A painting of a squirrel eating a burger" ,
114+ "negative_prompt" : "bad, ugly" ,
133115 "generator" : generator ,
134116 "num_inference_steps" : 2 ,
135117 "guidance_scale" : 5.0 ,
@@ -140,14 +122,14 @@ def get_dummy_inputs(self, device, seed=0):
140122 }
141123 return inputs
142124
143- def test_flux_different_prompts (self ):
125+ def test_chroma_different_prompts (self ):
144126 pipe = self .pipeline_class (** self .get_dummy_components ()).to (torch_device )
145127
146128 inputs = self .get_dummy_inputs (torch_device )
147129 output_same_prompt = pipe (** inputs ).images [0 ]
148130
149131 inputs = self .get_dummy_inputs (torch_device )
150- inputs ["prompt_2 " ] = "a different prompt"
132+ inputs ["prompt " ] = "a different prompt"
151133 output_different_prompts = pipe (** inputs ).images [0 ]
152134
153135 max_diff = np .abs (output_same_prompt - output_different_prompts ).max ()
@@ -196,7 +178,7 @@ def test_fused_qkv_projections(self):
196178 "Original outputs should match when fused QKV projections are disabled."
197179 )
198180
199- def test_flux_image_output_shape (self ):
181+ def test_chroma_image_output_shape (self ):
200182 pipe = self .pipeline_class (** self .get_dummy_components ()).to (torch_device )
201183 inputs = self .get_dummy_inputs (torch_device )
202184
@@ -210,13 +192,3 @@ def test_flux_image_output_shape(self):
210192 output_height , output_width , _ = image .shape
211193 assert (output_height , output_width ) == (expected_height , expected_width )
212194
213- def test_flux_true_cfg (self ):
214- pipe = self .pipeline_class (** self .get_dummy_components ()).to (torch_device )
215- inputs = self .get_dummy_inputs (torch_device )
216- inputs .pop ("generator" )
217-
218- no_true_cfg_out = pipe (** inputs , generator = torch .manual_seed (0 )).images [0 ]
219- inputs ["negative_prompt" ] = "bad quality"
220- inputs ["true_cfg_scale" ] = 2.0
221- true_cfg_out = pipe (** inputs , generator = torch .manual_seed (0 )).images [0 ]
222- assert not np .allclose (no_true_cfg_out , true_cfg_out )
0 commit comments