Commit 565e51c
committed: update testpipeline
1 parent 711dded

4 files changed: +36 -122 lines

src/diffusers/pipelines/omnigen/pipeline_omnigen.py (11 additions, 1 deletion)

@@ -139,7 +139,7 @@ class OmniGenPipeline(
 
     model_cpu_offload_seq = "transformer->vae"
     _optional_components = []
-    _callback_tensor_inputs = ["latents", "input_images_latents"]
+    _callback_tensor_inputs = ["latents"]
 
     def __init__(
         self,
@@ -435,6 +435,7 @@ def __call__(
             width=width,
             use_img_cfg=use_img_cfg,
             use_input_image_size_as_output=use_input_image_size_as_output,
+            num_images_per_prompt=num_images_per_prompt,
         )
         processed_data["input_ids"] = processed_data["input_ids"].to(device)
         processed_data["attention_mask"] = processed_data["attention_mask"].to(device)
@@ -448,6 +449,7 @@ def __call__(
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
         )
+        self._num_timesteps = len(timesteps)
 
         # 6. Prepare latents.
         if use_input_image_size_as_output:
@@ -496,6 +498,14 @@ def __call__(
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
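
The callback block added above follows the step-end hook convention used across diffusers pipelines: tensors named in callback_on_step_end_tensor_inputs are gathered into a dict, the user callback may modify them, and "latents" is written back into the denoising loop. A minimal usage sketch, assuming a published OmniGen checkpoint id and CUDA availability (neither is pinned down by this commit):

import torch
from diffusers import OmniGenPipeline

# Hypothetical checkpoint id; substitute whichever OmniGen weights you use.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

def log_latents(pipeline, step, timestep, callback_kwargs):
    # "latents" is available because it is listed in _callback_tensor_inputs.
    latents = callback_kwargs["latents"]
    print(f"step {step + 1}/{pipeline._num_timesteps}: latents norm {latents.float().norm().item():.2f}")
    return callback_kwargs  # tensors returned here are written back into the loop

image = pipe(
    "A painting of a squirrel eating a burger",
    num_inference_steps=20,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]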

src/diffusers/pipelines/omnigen/processor_omnigen.py (12 additions, 10 deletions)

@@ -135,6 +135,7 @@ def __call__(
         use_img_cfg: bool = True,
         separate_cfg_input: bool = False,
         use_input_image_size_as_output: bool = False,
+        num_images_per_prompt: int = 1,
     ) -> Dict:
         if isinstance(instructions, str):
             instructions = [instructions]
@@ -161,17 +162,18 @@ def __call__(
             else:
                 img_cfg_mllm_input = neg_mllm_input
 
-            if use_input_image_size_as_output:
-                input_data.append(
-                    (
-                        mllm_input,
-                        neg_mllm_input,
-                        img_cfg_mllm_input,
-                        [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
+            for _ in range(num_images_per_prompt):
+                if use_input_image_size_as_output:
+                    input_data.append(
+                        (
+                            mllm_input,
+                            neg_mllm_input,
+                            img_cfg_mllm_input,
+                            [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
+                        )
                     )
-                )
-            else:
-                input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+                else:
+                    input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
 
         return self.collator(input_data)
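
The loop introduced above repeats each prompt's (positive, negative, image-cfg, size) tuple once per requested image, so the downstream collator sees batch_size = len(prompts) * num_images_per_prompt. A self-contained sketch of that pattern, with placeholder tuple contents rather than the processor's real structures:

def build_input_data(prompts, height=16, width=16, num_images_per_prompt=1):
    input_data = []
    for prompt in prompts:
        # Stand-ins for mllm_input / neg_mllm_input / img_cfg_mllm_input.
        entry = (prompt, "negative: " + prompt, None, [height, width])
        # One tuple per requested image, mirroring the diff above.
        for _ in range(num_images_per_prompt):
            input_data.append(entry)
    return input_data

assert len(build_input_data(["a cat", "a dog"], num_images_per_prompt=3)) == 6
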
src/diffusers/utils/testing_utils.py (5 additions, 5 deletions)

@@ -1077,28 +1077,28 @@ def _is_torch_fp64_available(device):
 # Function definitions
 BACKEND_EMPTY_CACHE = {
     "cuda": torch.cuda.empty_cache,
-    "xpu": torch.xpu.empty_cache,
+    # "xpu": torch.xpu.empty_cache,
     "cpu": None,
     "mps": torch.mps.empty_cache,
     "default": None,
 }
 BACKEND_DEVICE_COUNT = {
     "cuda": torch.cuda.device_count,
-    "xpu": torch.xpu.device_count,
+    # "xpu": torch.xpu.device_count,
     "cpu": lambda: 0,
     "mps": lambda: 0,
     "default": 0,
 }
 BACKEND_MANUAL_SEED = {
     "cuda": torch.cuda.manual_seed,
-    "xpu": torch.xpu.manual_seed,
+    # "xpu": torch.xpu.manual_seed,
     "cpu": torch.manual_seed,
     "mps": torch.mps.manual_seed,
     "default": torch.manual_seed,
 }
 BACKEND_RESET_PEAK_MEMORY_STATS = {
     "cuda": torch.cuda.reset_peak_memory_stats,
-    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+    # "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
     "cpu": None,
     "mps": None,
     "default": None,
@@ -1112,7 +1112,7 @@ def _is_torch_fp64_available(device):
 }
 BACKEND_MAX_MEMORY_ALLOCATED = {
     "cuda": torch.cuda.max_memory_allocated,
-    "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+    # "xpu": getattr(torch.xpu, "max_memory_allocated", None),
     "cpu": 0,
     "mps": 0,
     "default": 0,

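Each BACKEND_* table above maps a device string to a backend-specific callable, or to a constant/None where a backend has nothing to do; tests look the entry up by device and invoke it if callable. A self-contained sketch of that dispatch pattern (abridged table, and the helper name here is hypothetical):

import torch

# Abridged copy of one table; the real ones live in testing_utils.py.
BACKEND_EMPTY_CACHE = {
    "cuda": torch.cuda.empty_cache,
    "cpu": None,
    "default": None,
}

def backend_dispatch(device, table):
    # Unknown device strings fall back to the "default" entry; entries may be
    # callables, plain values, or None (meaning "nothing to do").
    entry = table.get(device, table["default"])
    return entry() if callable(entry) else entry

backend_dispatch("cuda" if torch.cuda.is_available() else "cpu", BACKEND_EMPTY_CACHE)
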
tests/pipelines/omnigen/test_pipeline_omnigen.py (8 additions, 106 deletions)

@@ -34,112 +34,14 @@ def get_dummy_components(self):
         torch.manual_seed(0)
 
         transformer = OmniGenTransformer2DModel(
-            rope_scaling={
-                "long_factor": [
-                    1.0299999713897705,
-                    1.0499999523162842,
-                    1.0499999523162842,
-                    1.0799999237060547,
-                    1.2299998998641968,
-                    1.2299998998641968,
-                    1.2999999523162842,
-                    1.4499999284744263,
-                    1.5999999046325684,
-                    1.6499998569488525,
-                    1.8999998569488525,
-                    2.859999895095825,
-                    3.68999981880188,
-                    5.419999599456787,
-                    5.489999771118164,
-                    5.489999771118164,
-                    9.09000015258789,
-                    11.579999923706055,
-                    15.65999984741211,
-                    15.769999504089355,
-                    15.789999961853027,
-                    18.360000610351562,
-                    21.989999771118164,
-                    23.079999923706055,
-                    30.009998321533203,
-                    32.35000228881836,
-                    32.590003967285156,
-                    35.56000518798828,
-                    39.95000457763672,
-                    53.840003967285156,
-                    56.20000457763672,
-                    57.95000457763672,
-                    59.29000473022461,
-                    59.77000427246094,
-                    59.920005798339844,
-                    61.190006256103516,
-                    61.96000671386719,
-                    62.50000762939453,
-                    63.3700065612793,
-                    63.48000717163086,
-                    63.48000717163086,
-                    63.66000747680664,
-                    63.850006103515625,
-                    64.08000946044922,
-                    64.760009765625,
-                    64.80001068115234,
-                    64.81001281738281,
-                    64.81001281738281,
-                ],
-                "short_factor": [
-                    1.05,
-                    1.05,
-                    1.05,
-                    1.1,
-                    1.1,
-                    1.1,
-                    1.2500000000000002,
-                    1.2500000000000002,
-                    1.4000000000000004,
-                    1.4500000000000004,
-                    1.5500000000000005,
-                    1.8500000000000008,
-                    1.9000000000000008,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.1000000000000005,
-                    2.1000000000000005,
-                    2.2,
-                    2.3499999999999996,
-                    2.3499999999999996,
-                    2.3499999999999996,
-                    2.3499999999999996,
-                    2.3999999999999995,
-                    2.3999999999999995,
-                    2.6499999999999986,
-                    2.6999999999999984,
-                    2.8999999999999977,
-                    2.9499999999999975,
-                    3.049999999999997,
-                    3.049999999999997,
-                    3.049999999999997,
-                ],
-                "type": "su",
-            },
-            patch_size=2,
+            hidden_size=16,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            intermediate_size=32,
+            num_layers=1,
             in_channels=4,
-            pos_embed_max_size=192,
+            time_step_dim=4,
+            rope_scaling={"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
         )
 
         torch.manual_seed(0)
@@ -174,7 +76,7 @@ def get_dummy_inputs(self, device, seed=0):
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
             "generator": generator,
-            "num_inference_steps": 2,
+            "num_inference_steps": 1,
             "guidance_scale": 3.0,
             "output_type": "np",
             "height": 16,

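The dummy transformer config above can be instantiated on its own to sanity-check that the shrunken model stays tiny. A sketch assuming OmniGenTransformer2DModel is importable from the diffusers top level and accepts exactly the keyword arguments shown in the diff:

import torch
from diffusers import OmniGenTransformer2DModel

torch.manual_seed(0)
transformer = OmniGenTransformer2DModel(
    hidden_size=16,
    num_attention_heads=4,
    num_key_value_heads=4,
    intermediate_size=32,
    num_layers=1,
    in_channels=4,
    time_step_dim=4,
    rope_scaling={"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
)
# A 1-layer, 16-dim model keeps the CPU-only fast tests cheap to run.
print(sum(p.numel() for p in transformer.parameters()))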