
Commit 3834391

fix(backend): revert non-blocking device transfer
In #6490 we enabled non-blocking torch device transfers throughout the model manager's memory management code. When using this torch feature, torch attempts to wait until the tensor transfer has completed before allowing any access to the tensor. Theoretically, that should make this a safe feature to use. It provides a small performance improvement, but it causes race conditions in some situations. Specific platforms/systems are affected, and complicated data dependencies can make this unsafe:

- Intermittent black images on MPS devices - reported on Discord and in #6545, fixed with special handling in #6549.
- Intermittent OOMs and black images on a P4000 GPU on Windows - reported in #6613, fixed in this commit.

On my system, I haven't experienced any issues with generation, but targeted testing of non-blocking ops did expose a race condition when moving tensors from CUDA to CPU.

One workaround is to use torch streams with manual sync points. Our application logic is complicated enough that this would be a lot of work and feels ripe for edge cases and missed spots. Much safer is to fully revert non-blocking transfers - which is what this change does.
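
For context, a minimal standalone sketch (not code from this repo) of the CUDA-to-CPU race described above, and the manual sync point that makes the read safe:

import torch

# Sketch of the hazard: a non-blocking device-to-host copy is merely queued
# on the CUDA stream, so the CPU can observe the destination tensor before
# the copy has finished unless it synchronizes first.
if torch.cuda.is_available():
    src = torch.randn(1 << 20, device="cuda")
    dst = src.to("cpu", non_blocking=True)
    # Reading dst here would be unsafe: it may still hold stale data.
    torch.cuda.synchronize()  # manual sync point
    safe_sum = dst.sum()  # now safe to read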
1 parent 5a0c998 commit 3834391

File tree: 8 files changed (+43 / -115 lines)

invokeai/backend/ip_adapter/ip_adapter.py

Lines changed: 3 additions & 5 deletions
@@ -124,16 +124,14 @@ def __init__(
             self.device, dtype=self.dtype
         )
 
-    def to(
-        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
-    ):
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         if device is not None:
             self.device = device
         if dtype is not None:
             self.dtype = dtype
 
-        self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
-        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+        self._image_proj_model.to(device=self.device, dtype=self.dtype)
+        self.attn_weights.to(device=self.device, dtype=self.dtype)
 
     def calc_size(self) -> int:
         # HACK(ryand): Fix this issue with circular imports.

invokeai/backend/lora.py

Lines changed: 29 additions & 65 deletions
@@ -11,7 +11,6 @@
 
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel
-from invokeai.backend.util.devices import TorchDevice
 
 
 class LoRALayerBase:
@@ -57,14 +56,9 @@ def calc_size(self) -> int:
             model_size += val.nelement() * val.element_size()
         return model_size
 
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.bias = self.bias.to(device=device, dtype=dtype)
 
 
 # TODO: find and debug lora/locon with bias
@@ -106,19 +100,14 @@ def calc_size(self) -> int:
             model_size += val.nelement() * val.element_size()
         return model_size
 
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
-        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        super().to(device=device, dtype=dtype)
 
-        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.up = self.up.to(device=device, dtype=dtype)
+        self.down = self.down.to(device=device, dtype=dtype)
 
         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.mid = self.mid.to(device=device, dtype=dtype)
 
 
 class LoHALayer(LoRALayerBase):
@@ -167,23 +156,18 @@ def calc_size(self) -> int:
             model_size += val.nelement() * val.element_size()
         return model_size
 
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t1 = self.t1.to(device=device, dtype=dtype)
 
-        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t2 = self.t2.to(device=device, dtype=dtype)
 
 
 class LoKRLayer(LoRALayerBase):
@@ -264,32 +248,27 @@ def calc_size(self) -> int:
             model_size += val.nelement() * val.element_size()
         return model_size
 
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)
 
         if self.w1 is not None:
             self.w1 = self.w1.to(device=device, dtype=dtype)
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
 
         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2 = self.w2.to(device=device, dtype=dtype)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
 
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t2 = self.t2.to(device=device, dtype=dtype)
 
 
 class FullLayer(LoRALayerBase):
@@ -319,15 +298,10 @@ def calc_size(self) -> int:
         model_size += self.weight.nelement() * self.weight.element_size()
         return model_size
 
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.weight = self.weight.to(device=device, dtype=dtype)
 
 
 class IA3Layer(LoRALayerBase):
@@ -359,16 +333,11 @@ def calc_size(self) -> int:
         model_size += self.on_input.nelement() * self.on_input.element_size()
         return model_size
 
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ):
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.on_input = self.on_input.to(device=device, dtype=dtype)
 
 
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -390,15 +359,10 @@ def __init__(
     def name(self) -> str:
         return self._name
 
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            layer.to(device=device, dtype=dtype)
 
     def calc_size(self) -> int:
         model_size = 0
@@ -521,7 +485,7 @@ def from_checkpoint(
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()
 
-            layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
+            layer.to(device=device, dtype=dtype)
             model.layers[layer_key] = layer
 
         return model

invokeai/backend/model_manager/load/model_cache/model_cache_default.py

Lines changed: 2 additions & 4 deletions
@@ -289,11 +289,9 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(
-                        target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
-                    )
+                    new_dict[k] = v.to(target_device, copy=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
+            cache_entry.model.to(target_device)
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)

invokeai/backend/model_patcher.py

Lines changed: 5 additions & 10 deletions
@@ -139,15 +139,12 @@ def apply_lora(
                         # We intentionally move to the target device first, then cast. Experimentally, this was found to
                         # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                         # same thing in a single call to '.to(...)'.
-                        layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
-                        layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
+                        layer.to(device=device)
+                        layer.to(dtype=torch.float32)
                         # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                         # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                         layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(
-                            device=TorchDevice.CPU_DEVICE,
-                            non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
-                        )
+                        layer.to(device=TorchDevice.CPU_DEVICE)
 
                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                         if module.weight.shape != layer_weight.shape:
@@ -156,17 +153,15 @@ def apply_lora(
                             layer_weight = layer_weight.reshape(module.weight.shape)
 
                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
+                        module.weight += layer_weight.to(dtype=dtype)
 
             yield  # wait for context manager exit
 
         finally:
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(
-                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
-                    )
+                    model.get_submodule(module_key).weight.copy_(weight)
 
     @classmethod
     @contextmanager
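
The comment preserved in this hunk explains the ordering: move to the target device first, then cast. A rough illustration of that pattern (hypothetical sizes, outside the repo):

import torch

# Two-step move-then-cast vs. a single combined call. Per the comment in
# the diff above, the two-step form was measured to be faster for 16-bit
# CPU tensors headed to CUDA; the results are identical because
# float16 -> float32 casting is exact.
w = torch.randn(4096, 4096, dtype=torch.float16)
if torch.cuda.is_available():
    a = w.to(device="cuda").to(dtype=torch.float32)  # move, then cast
    b = w.to(device="cuda", dtype=torch.float32)     # combined call
    assert torch.equal(a, b)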

invokeai/backend/onnx/onnx_runtime.py

Lines changed: 1 addition & 6 deletions
@@ -190,12 +190,7 @@ def __call__(self, **kwargs):
         return self.session.run(None, inputs)
 
     # compatability with RawModel ABC
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         pass
 
     # compatability with diffusers load code

invokeai/backend/raw_model.py

Lines changed: 1 addition & 6 deletions
@@ -20,10 +20,5 @@ class RawModel(ABC):
     """Abstract base class for 'Raw' model wrappers."""
 
     @abstractmethod
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         pass

invokeai/backend/textual_inversion.py

Lines changed: 2 additions & 7 deletions
@@ -65,17 +65,12 @@ def from_checkpoint(
 
         return result
 
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         if not torch.cuda.is_available():
             return
         for emb in [self.embedding, self.embedding_2]:
             if emb is not None:
-                emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
+                emb.to(device=device, dtype=dtype)
 
     def calc_size(self) -> int:
         """Get the size of this model in bytes."""

invokeai/backend/util/devices.py

Lines changed: 0 additions & 12 deletions
@@ -112,15 +112,3 @@ def empty_cache(cls) -> None:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
-
-    @staticmethod
-    def get_non_blocking(to_device: torch.device) -> bool:
-        """Return the non_blocking flag to be used when moving a tensor to a given device.
-
-        MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
-        When moving _from_ MPS, we can use non-blocking operations.
-
-        See:
-        - https://github.com/pytorch/pytorch/issues/107455
-        - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
-        """
-        return False if to_device.type == "mps" else True
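
For reference, the deleted helper encoded "non-blocking everywhere except when the destination is MPS". After this commit, call sites fall back to torch's default blocking transfer; an illustrative before/after of the call pattern:

import torch

t = torch.ones(4)
device = torch.device("cpu")

# Before this commit (helper now removed):
#   t = t.to(device, non_blocking=TorchDevice.get_non_blocking(device))
# After this commit: default non_blocking=False, safe on all platforms.
t = t.to(device)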
