Skip to content

Commit 2afea72

Browse files
Sai-Suraj-27sayakpaulyiyixuxu
authored
refactor: Refactored code by Merging isinstance calls (#7710)
* Merged isinstance calls to make the code simpler. * Corrected formatting errors using ruff. --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent 0f111ab commit 2afea72

File tree

11 files changed

+11
-27
lines changed

11 files changed

+11
-27
lines changed

examples/community/pipeline_stable_diffusion_upscale_ldm3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def check_inputs(
460460
)
461461

462462
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
463-
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
463+
if isinstance(image, (list, np.ndarray, torch.Tensor)):
464464
if prompt is not None and isinstance(prompt, str):
465465
batch_size = 1
466466
elif prompt is not None and isinstance(prompt, list):

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: i
685685
positive_len = 768
686686
if isinstance(cross_attention_dim, int):
687687
positive_len = cross_attention_dim
688-
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
688+
elif isinstance(cross_attention_dim, (list, tuple)):
689689
positive_len = cross_attention_dim[0]
690690

691691
feature_type = "text-only" if attention_type == "gated" else "text-image"

src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def __init__(
817817
positive_len = 768
818818
if isinstance(cross_attention_dim, int):
819819
positive_len = cross_attention_dim
820-
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
820+
elif isinstance(cross_attention_dim, (list, tuple)):
821821
positive_len = cross_attention_dim[0]
822822

823823
feature_type = "text-only" if attention_type == "gated" else "text-image"

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def check_inputs(
197197
)
198198

199199
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
200-
if isinstance(image, list) or isinstance(image, np.ndarray):
200+
if isinstance(image, (list, np.ndarray)):
201201
if prompt is not None and isinstance(prompt, str):
202202
batch_size = 1
203203
elif prompt is not None and isinstance(prompt, list):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def check_inputs(self, prompt, image, callback_steps):
221221
)
222222

223223
# verify batch size of prompt and image are same if image is a list or tensor
224-
if isinstance(image, list) or isinstance(image, torch.Tensor):
224+
if isinstance(image, (list, torch.Tensor)):
225225
if isinstance(prompt, str):
226226
batch_size = 1
227227
else:

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def check_inputs(
468468
)
469469

470470
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
471-
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
471+
if isinstance(image, (list, np.ndarray, torch.Tensor)):
472472
if prompt is not None and isinstance(prompt, str):
473473
batch_size = 1
474474
elif prompt is not None and isinstance(prompt, list):

src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def preprocess(image):
185185
def preprocess_mask(mask, batch_size: int = 1):
186186
if not isinstance(mask, torch.Tensor):
187187
# preprocess mask
188-
if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray):
188+
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
189189
mask = [mask]
190190

191191
if isinstance(mask, list):

src/diffusers/schedulers/scheduling_consistency_models.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -347,11 +347,7 @@ def step(
347347
otherwise a tuple is returned where the first element is the sample tensor.
348348
"""
349349

350-
if (
351-
isinstance(timestep, int)
352-
or isinstance(timestep, torch.IntTensor)
353-
or isinstance(timestep, torch.LongTensor)
354-
):
350+
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
355351
raise ValueError(
356352
(
357353
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,7 @@ def step(
310310
returned, otherwise a tuple is returned where the first element is the sample tensor.
311311
"""
312312

313-
if (
314-
isinstance(timestep, int)
315-
or isinstance(timestep, torch.IntTensor)
316-
or isinstance(timestep, torch.LongTensor)
317-
):
313+
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
318314
raise ValueError(
319315
(
320316
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -375,11 +375,7 @@ def step(
375375
376376
"""
377377

378-
if (
379-
isinstance(timestep, int)
380-
or isinstance(timestep, torch.IntTensor)
381-
or isinstance(timestep, torch.LongTensor)
382-
):
378+
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
383379
raise ValueError(
384380
(
385381
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"

0 commit comments

Comments
 (0)