Skip to content

Commit 39407c1

Browse files
committed
unet.config.in_channels
1 parent 076f304 commit 39407c1

File tree

7 files changed

+7
-7
lines changed

7 files changed

+7
-7
lines changed

examples/community/adaptive_mask_inpainting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def __init__(
442442
unet._internal_dict = FrozenDict(new_config)
443443

444444
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
445-
if unet.config.in_channels != 9:
445+
if unet is not None and unet.config.in_channels != 9:
446446
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
447447

448448
self.register_modules(

examples/community/stable_diffusion_reference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def __init__(
206206
new_config["sample_size"] = 64
207207
unet._internal_dict = FrozenDict(new_config)
208208
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
209-
if unet.config.in_channels != 4:
209+
if unet is not None and unet.config.in_channels != 4:
210210
logger.warning(
211211
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
212212
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"

examples/community/stable_diffusion_repaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def __init__(
261261
new_config["sample_size"] = 64
262262
unet._internal_dict = FrozenDict(new_config)
263263
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
264-
if unet.config.in_channels != 4:
264+
if unet is not None and unet.config.in_channels != 4:
265265
logger.warning(
266266
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
267267
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"

src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __init__(
174174
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
175175
)
176176

177-
if unet.config.in_channels != 6:
177+
if unet is not None and unet.config.in_channels != 6:
178178
logger.warning(
179179
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
180180
)

src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def __init__(
176176
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
177177
)
178178

179-
if unet.config.in_channels != 6:
179+
if unet is not None and unet.config.in_channels != 6:
180180
logger.warning(
181181
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
182182
)

src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(
132132
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
133133
)
134134

135-
if unet.config.in_channels != 6:
135+
if unet is not None and unet.config.in_channels != 6:
136136
logger.warning(
137137
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
138138
)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def __init__(
241241
unet._internal_dict = FrozenDict(new_config)
242242

243243
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
244-
if unet.config.in_channels != 9:
244+
if unet is not None and unet.config.in_channels != 9:
245245
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
246246

247247
self.register_modules(

0 commit comments

Comments
 (0)