Skip to content

Commit 18454cf

Browse files
committed
device_map in load_model_dict_into_meta
1 parent 1871a69 commit 18454cf

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

src/diffusers/loaders/transformer_flux.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
8282
if not low_cpu_mem_usage:
8383
image_projection.load_state_dict(updated_state_dict, strict=True)
8484
else:
85-
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
85+
device_map = {"": self.device}
86+
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
8687

8788
return image_projection
8889

@@ -151,9 +152,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
151152
if not low_cpu_mem_usage:
152153
attn_procs[name].load_state_dict(value_dict)
153154
else:
154-
device = self.device
155+
device_map = {"": self.device}
155156
dtype = self.dtype
156-
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
157+
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
157158

158159
key_id += 1
159160

src/diffusers/loaders/transformer_sd3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ def _convert_ip_adapter_attn_to_diffusers(
7575
if not low_cpu_mem_usage:
7676
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
7777
else:
78+
device_map = {"": self.device}
7879
load_model_dict_into_meta(
79-
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
80+
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
8081
)
8182

8283
return attn_procs
@@ -144,7 +145,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(
144145
if not low_cpu_mem_usage:
145146
image_proj.load_state_dict(updated_state_dict, strict=True)
146147
else:
147-
load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
148+
device_map = {"": self.device}
149+
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
148150

149151
return image_proj
150152

src/diffusers/loaders/unet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
753753
if not low_cpu_mem_usage:
754754
image_projection.load_state_dict(updated_state_dict, strict=True)
755755
else:
756-
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
756+
device_map = {"": self.device}
757+
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
757758

758759
return image_projection
759760

@@ -846,7 +847,8 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
846847
else:
847848
device = next(iter(value_dict.values())).device
848849
dtype = next(iter(value_dict.values())).dtype
849-
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
850+
device_map = {"": device}
851+
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
850852

851853
key_id += 2
852854

0 commit comments

Comments
 (0)