Skip to content

Commit 16b6583

Browse files
committed
allow input_fields as input & update message
1 parent f552773 commit 16b6583

8 files changed

+51
-23
lines changed

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING
16+
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
1717

1818
import torch
1919

@@ -73,14 +73,18 @@ def __init__(
7373
self.use_original_formulation = use_original_formulation
7474
self.momentum_buffer = None
7575

76-
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
76+
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
77+
78+
if input_fields is None:
79+
input_fields = self._input_fields
80+
7781
if self._step == 0:
7882
if self.adaptive_projected_guidance_momentum is not None:
7983
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
8084
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
8185
data_batches = []
8286
for i in range(self.num_conditions):
83-
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
87+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
8488
data_batches.append(data_batch)
8589
return data_batches
8690

src/diffusers/guiders/auto_guidance.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import List, Optional, Union, TYPE_CHECKING
16+
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
1717

1818
import torch
1919

@@ -120,11 +120,15 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
120120
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
121121
registry.remove_hook(name, recurse=True)
122122

123-
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
123+
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
124+
125+
if input_fields is None:
126+
input_fields = self._input_fields
127+
124128
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
125129
data_batches = []
126130
for i in range(self.num_conditions):
127-
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
131+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
128132
data_batches.append(data_batch)
129133
return data_batches
130134

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING
16+
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
1717

1818
import torch
1919

@@ -75,11 +75,15 @@ def __init__(
7575
self.guidance_rescale = guidance_rescale
7676
self.use_original_formulation = use_original_formulation
7777

78-
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
78+
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
79+
80+
if input_fields is None:
81+
input_fields = self._input_fields
82+
7983
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
8084
data_batches = []
8185
for i in range(self.num_conditions):
82-
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
86+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
8387
data_batches.append(data_batch)
8488
return data_batches
8589

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING
16+
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
1717

1818
import torch
1919

@@ -73,11 +73,15 @@ def __init__(
7373
self.guidance_rescale = guidance_rescale
7474
self.use_original_formulation = use_original_formulation
7575

76-
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
76+
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
77+
78+
if input_fields is None:
79+
input_fields = self._input_fields
80+
7781
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
7882
data_batches = []
7983
for i in range(self.num_conditions):
80-
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
84+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
8185
data_batches.append(data_batch)
8286
return data_batches
8387

src/diffusers/guiders/guider_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da
174174
from ..pipelines.modular_pipeline import BlockState
175175

176176
if input_fields is None:
177-
raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.")
177+
raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.")
178178
data_batch = {}
179179
for key, value in input_fields.items():
180180
try:
@@ -186,7 +186,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da
186186
# We've already checked that value is a string or a tuple of strings with length 2
187187
pass
188188
except AttributeError:
189-
raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.")
189+
logger.warning(f"`data` does not have attribute(s) {value}, skipping.")
190190
data_batch[cls._identifier_key] = identifier
191191
return BlockState(**data_batch)
192192

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import List, Optional, Union, TYPE_CHECKING
16+
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
1717

1818
import torch
1919

@@ -156,7 +156,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
156156
for hook_name in self._skip_layer_hook_names:
157157
registry.remove_hook(hook_name, recurse=True)
158158

159-
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
159+
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
160+
161+
if input_fields is None:
162+
input_fields = self._input_fields
163+
160164
if self.num_conditions == 1:
161165
tuple_indices = [0]
162166
input_predictions = ["pred_cond"]
@@ -168,7 +172,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
168172
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
169173
data_batches = []
170174
for i in range(self.num_conditions):
171-
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
175+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
172176
data_batches.append(data_batch)
173177
return data_batches
174178

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import List, Optional, Union, TYPE_CHECKING
16+
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
1717

1818
import torch
1919

@@ -149,7 +149,11 @@ def cleanup_models(self, denoiser: torch.nn.Module):
149149
for hook_name in self._seg_layer_hook_names:
150150
registry.remove_hook(hook_name, recurse=True)
151151

152-
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
152+
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
153+
154+
if input_fields is None:
155+
input_fields = self._input_fields
156+
153157
if self.num_conditions == 1:
154158
tuple_indices = [0]
155159
input_predictions = ["pred_cond"]
@@ -161,7 +165,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
161165
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
162166
data_batches = []
163167
for i in range(self.num_conditions):
164-
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
168+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
165169
data_batches.append(data_batch)
166170
return data_batches
167171

src/diffusers/guiders/tangential_classifier_free_guidance.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING
16+
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
1717

1818
import torch
1919

@@ -62,11 +62,15 @@ def __init__(
6262
self.guidance_rescale = guidance_rescale
6363
self.use_original_formulation = use_original_formulation
6464

65-
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
65+
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
66+
67+
if input_fields is None:
68+
input_fields = self._input_fields
69+
6670
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
6771
data_batches = []
6872
for i in range(self.num_conditions):
69-
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
73+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
7074
data_batches.append(data_batch)
7175
return data_batches
7276

0 commit comments

Comments
 (0)