
Commit cab0dd8

test_flux_ip_adapter_inference
1 parent 9276ced commit cab0dd8

File tree

1 file changed: +109 −0 lines changed


tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 109 additions & 0 deletions
@@ -298,3 +298,112 @@ def test_flux_inference(self):
        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

        assert max_diff < 1e-4


@slow
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
    pipeline_class = FluxPipeline
    repo_id = "black-forest-labs/FLUX.1-dev"
    image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14"
    weight_name = "ip_adapter.safetensors"
    ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        prompt_embeds = torch.load(
            hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
        )
        pooled_prompt_embeds = torch.load(
            hf_hub_download(
                repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
            )
        )
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8)
        return {
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "negative_prompt_embeds": negative_prompt_embeds,
            "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
            "ip_adapter_image": ip_adapter_image,
            "num_inference_steps": 2,
            "guidance_scale": 3.5,
            "true_cfg_scale": 4.0,
            "max_sequence_length": 256,
            "output_type": "np",
            "generator": generator,
        }

    def test_flux_ip_adapter_inference(self):
        pipe = self.pipeline_class.from_pretrained(
            self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
        )
        pipe.load_ip_adapter(
            self.ip_adapter_repo_id,
            weight_name=self.weight_name,
            image_encoder_pretrained_model_name_or_path=self.image_encoder_pretrained_model_name_or_path,
        )
        pipe.set_ip_adapter_scale(1.0)
        pipe.enable_model_cpu_offload()

        inputs = self.get_inputs(torch_device)

        image = pipe(**inputs).images[0]
        image_slice = image[0, :10, :10]

        expected_slice = np.array(
            [
                0.1855,
                0.1680,
                0.1406,
                0.1953,
                0.1699,
                0.1465,
                0.2012,
                0.1738,
                0.1484,
                0.2051,
                0.1797,
                0.1523,
                0.2012,
                0.1719,
                0.1445,
                0.2070,
                0.1777,
                0.1465,
                0.2090,
                0.1836,
                0.1484,
                0.2129,
                0.1875,
                0.1523,
                0.2090,
                0.1816,
                0.1484,
                0.2110,
                0.1836,
                0.1543,
            ],
            dtype=np.float32,
        )

        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

        assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
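
For readers who want to try the same IP-Adapter path outside the test harness, a minimal standalone sketch follows. The repository IDs, adapter weight name, and scale are taken from the diff above; the prompt, step count, and seed are illustrative assumptions and not values from this commit.

# Minimal standalone sketch of the IP-Adapter flow this test exercises.
# Repo IDs, weight name, and scale come from the diff above; the prompt,
# step count, and seed are illustrative assumptions.
import numpy as np
import torch

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)
pipe.enable_model_cpu_offload()  # keeps peak VRAM down, as in the test

# The test feeds an all-black 1024x1024 reference image; any RGB image works here.
ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8)

image = pipe(
    prompt="a photo of a cat",  # assumed prompt, not part of the test
    ip_adapter_image=ip_adapter_image,
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.Generator(device="cpu").manual_seed(0),
).images[0]
image.save("flux_ip_adapter.png")

Since the new test is gated behind @slow and the big-GPU marker, it would typically be run with something like RUN_SLOW=1 python -m pytest tests/pipelines/flux/test_pipeline_flux.py -k test_flux_ip_adapter_inference (assuming diffusers' usual slow-test setup).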
