Skip to content

Commit 770c12c

Browse files
committed
Update missing inputs and imports.
1 parent 4491481 commit 770c12c

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2525
from ...utils.torch_utils import randn_tensor
2626
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
27+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2728

2829

2930
if is_torch_xla_available():
@@ -131,6 +132,12 @@ class AuraFlowPipeline(DiffusionPipeline):
131132

132133
_optional_components = []
133134
model_cpu_offload_seq = "text_encoder->transformer->vae"
135+
_callback_tensor_inputs = [
136+
"latents",
137+
"prompt_embeds",
138+
"add_text_embeds",
139+
"add_time_ids",
140+
]
134141

135142
def __init__(
136143
self,

src/diffusers/pipelines/lumina/pipeline_lumina.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from ...utils.torch_utils import randn_tensor
3939
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
40+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
4041

4142

4243
if is_torch_xla_available():

0 commit comments

Comments
 (0)