Skip to content

Commit 68301b3

Browse files
authored
Merge branch 'main' into diffusions-xpu
2 parents a6b2065 + c36f848 commit 68301b3

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,8 @@ def _get_signature_types(cls):
16651665
signature_types[k] = (v.annotation,)
16661666
elif get_origin(v.annotation) == Union:
16671667
signature_types[k] = get_args(v.annotation)
1668+
elif get_origin(v.annotation) in [List, Dict, list, dict]:
1669+
signature_types[k] = (v.annotation,)
16681670
else:
16691671
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
16701672
return signature_types

src/diffusers/utils/torch_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def maybe_allow_in_graph(cls):
3838
def randn_tensor(
3939
shape: Union[Tuple, List],
4040
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
41-
device: Optional["torch.device"] = None,
41+
device: Optional[Union[str, "torch.device"]] = None,
4242
dtype: Optional["torch.dtype"] = None,
4343
layout: Optional["torch.layout"] = None,
4444
):
@@ -47,6 +47,8 @@ def randn_tensor(
4747
is always created on the CPU.
4848
"""
4949
# device on which tensor is created defaults to device
50+
if isinstance(device, str):
51+
device = torch.device(device)
5052
rand_device = device
5153
batch_size = shape[0]
5254

0 commit comments

Comments
 (0)