Commit d50e321

Support SD2 attention slicing (#1397)
* Support SD2 attention slicing
* Support SD2 attention slicing
* Add more copies
* Use attn_num_head_channels in blocks
* fix-copies
* Update tests
* fix imports
1 parent 8e2c4cd commit d50e321

16 files changed (+892, −81 lines)
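Background for the diffs below: SD2-style checkpoints configure the UNet's `attention_head_dim` per block as a list rather than a single int, so every slicing check has to validate all entries. A minimal, self-contained sketch of the new validation logic; the head-dim list `[5, 10, 20, 20]` is illustrative, not read from a real config:

```python
# Sketch of the list-aware validation introduced by this commit.
def validate_slice_size(slice_size, attn_num_head_channels):
    # Normalize an int config to a one-element list so both cases share one code path.
    head_dims = attn_num_head_channels
    head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
    # slice_size must evenly divide every per-block head count.
    if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
        raise ValueError(
            f"Make sure slice_size {slice_size} is a common divisor of "
            f"the number of heads used in cross_attention: {head_dims}"
        )
    # slice_size cannot exceed the smallest per-block head count.
    if slice_size is not None and slice_size > min(head_dims):
        raise ValueError(
            f"slice_size {slice_size} has to be smaller or equal to "
            f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
        )

validate_slice_size(5, [5, 10, 20, 20])   # passes: 5 divides every entry
validate_slice_size(4, 8)                 # passes: int configs still work
# validate_slice_size(4, [5, 10, 20, 20])  # would raise: 4 does not divide 5
```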

src/diffusers/models/unet_2d_blocks.py

Lines changed: 24 additions & 18 deletions
```diff
@@ -404,15 +404,17 @@ def __init__(
         self.resnets = nn.ModuleList(resnets)
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
@@ -600,15 +602,17 @@ def __init__(
         self.gradient_checkpointing = False
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
@@ -1197,15 +1201,17 @@ def __init__(
         self.gradient_checkpointing = False
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
```

src/diffusers/models/unet_2d_condition.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -229,15 +229,17 @@ def __init__(
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
+        head_dims = self.config.attention_head_dim
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.config.attention_head_dim:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for block in self.down_blocks:
```
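`UNet2DConditionModel.set_attention_slice` applies the same list-aware check before forwarding the slice size to every block. A hedged usage sketch; the checkpoint id is illustrative of any SD2-style model whose config stores `attention_head_dim` as a list:

```python
# Sketch: exercising the new validation on a UNet with a list-valued
# `attention_head_dim`. The checkpoint id is an assumption for illustration.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2", subfolder="unet"
)
print(unet.config.attention_head_dim)  # a per-block list for SD2-style UNets

unet.set_attention_slice(5)     # accepted if 5 divides every entry of the list
unet.set_attention_slice(None)  # None skips both checks and disables slicing
```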

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -198,9 +198,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
```
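On the pipeline side, `"auto"` now resolves differently depending on the config type: half the head dim when it is an int (the SD1 case), the smallest entry when it is a list (the SD2 case). A short sketch of that resolution logic in isolation; `resolve_auto_slice_size` is a hypothetical helper, not a pipeline method:

```python
from typing import List, Union

def resolve_auto_slice_size(attention_head_dim: Union[int, List[int]]) -> int:
    # Mirrors the pipelines' "auto" branch: half the head size for an int
    # config, the smallest head size for a list config.
    if isinstance(attention_head_dim, int):
        return attention_head_dim // 2
    return min(attention_head_dim)

assert resolve_auto_slice_size(8) == 4                # SD1-style int config
assert resolve_auto_slice_size([5, 10, 20, 20]) == 5  # SD2-style list config
```

The same hunk is then propagated verbatim (via `fix-copies`) to the remaining pipelines below.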

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -193,9 +193,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
```

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -224,9 +224,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -197,9 +197,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -169,9 +169,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -193,9 +193,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -258,9 +258,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -206,9 +206,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
```
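End to end, enabling slicing on an SD2 pipeline now works the same way as on SD1. A hedged sketch, assuming a CUDA device; the checkpoint id is illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline

# The SD2 checkpoint id stands in for any model whose UNet config stores
# `attention_head_dim` as a per-block list.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2", torch_dtype=torch.float16
).to("cuda")

pipe.enable_attention_slicing()  # "auto" resolves to min of the head-dim list
image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```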
