
Commit 7383566

fix: CUDA error: out of memory on non_blocking calls
Removes the non_blocking argument from all device-to-CPU transfers. In certain environments (e.g. WSL), large non_blocking transfers throw a CUDA out-of-memory error regardless of the VRAM available. Also adjusts stream synchronization for modest performance gains with cpu_offload. Fixes #90, fixes #117
1 parent 36ee929 commit 7383566
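
For context: a non_blocking copy from GPU to CPU generally only overlaps with other work when the destination is pinned host memory, and per the issues referenced above, large asynchronous transfers can fail outright on some platforms (e.g. WSL). A minimal before/after sketch, not taken from this commit, using a hypothetical tensor `x`:

```python
import torch

# Illustrative only: offload a large tensor from GPU to CPU (assumes CUDA is available).
x = torch.randn(4096, 4096, device="cuda")

# Before: asynchronous device-to-CPU copy. With pageable (non-pinned) host memory
# this gains little, and on some setups (e.g. WSL) large async transfers were
# reported to raise CUDA out-of-memory errors (#90, #117).
# x_cpu = x.to("cpu", non_blocking=True)

# After: plain blocking copy; returns once the data is actually on the host.
x_cpu = x.to("cpu")
```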

File tree: OmniGen/scheduler.py, OmniGen/transformer.py (2 files changed, +15 -15 lines)


OmniGen/scheduler.py

Lines changed: 4 additions & 4 deletions
@@ -38,8 +38,8 @@ def evict_previous_layer(self, layer_idx: int):
                 prev_layer_idx = -1
             else:
                 prev_layer_idx = (layer_idx - 1) % len(self)
-            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
-            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu")
+            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu")


     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
@@ -50,9 +50,9 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
                 torch.cuda.current_stream().synchronize()
                 self.evict_previous_layer(layer_idx)
                 # Load current layer cache to its original device if not already there
-                original_device = self.original_device[layer_idx]
+                #original_device = self.original_device[layer_idx]
                 # self.prefetch_stream.synchronize(original_device)
-                torch.cuda.synchronize(self.prefetch_stream)
+                self.prefetch_stream.synchronize()
                 key_tensor = self.key_cache[layer_idx]
                 value_tensor = self.value_cache[layer_idx]

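On the scheduler side, the unused original_device lookup is commented out and the prefetch wait now calls the stream's own synchronize(): torch.cuda.synchronize() takes a device and waits for every stream on it, whereas Stream.synchronize() blocks only until the work queued on that one stream has finished, which matches the "modest performance gains" noted in the commit message. A small illustrative sketch (assumes a CUDA device; the kv tensor and prefetch_stream name here are hypothetical stand-ins):

```python
import torch

prefetch_stream = torch.cuda.Stream()
kv = torch.randn(2, 1024, 1024)              # hypothetical CPU-resident KV tensor

with torch.cuda.stream(prefetch_stream):
    kv = kv.to("cuda", non_blocking=True)    # prefetch queued on the side stream

# Device-wide wait: blocks until every stream on the device is idle.
# torch.cuda.synchronize()

# Narrower wait used after this commit: only the prefetch copies must finish.
prefetch_stream.synchronize()

print(kv.device)                             # cuda:0 once the copy has completed
```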

OmniGen/transformer.py

Lines changed: 11 additions & 11 deletions
@@ -33,26 +33,27 @@ def prefetch_layer(self, layer_idx: int, device: torch.device):
         "Starts prefetching the next layer cache"
         with torch.cuda.stream(self.prefetch_stream):
             # Prefetch next layer tensors to GPU
-            for name, param in self.layers[layer_idx].named_parameters():
-                param.data = param.data.to(device, non_blocking=True)
+            self.layers[layer_idx] = self.layers[layer_idx].to(device, non_blocking=True)

     def evict_previous_layer(self, layer_idx: int):
         "Moves the previous layer cache to the CPU"
         prev_layer_idx = layer_idx - 1
-        for name, param in self.layers[prev_layer_idx].named_parameters():
-            param.data = param.data.to("cpu", non_blocking=True)
+        self.layers[prev_layer_idx] = self.layers[prev_layer_idx].to("cpu")

     def get_offlaod_layer(self, layer_idx: int, device: torch.device):
         # init stream
         if not hasattr(self, "prefetch_stream"):
             self.prefetch_stream = torch.cuda.Stream()

         # delete previous layer
-        torch.cuda.current_stream().synchronize()
-        self.evict_previous_layer(layer_idx)
+        # main stream sync shouldn't be necessary since all computation on iter i-1 is finished by iter i
+        # torch.cuda.current_stream().synchronize()
+        # avoid extra eviction of last layer
+        if layer_idx > 0:
+            self.evict_previous_layer(layer_idx)

         # make sure the current layer is ready
-        torch.cuda.synchronize(self.prefetch_stream)
+        self.prefetch_stream.synchronize()

         # load next layer
         self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
@@ -133,10 +134,9 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        layer_idx = -1
-        for decoder_layer in self.layers:
-            layer_idx += 1
-
+        for layer_idx in range(len(self.layers)):
+            # direct indexing since offloading may mutate self.layers during iteration
+            decoder_layer = self.layers[layer_idx]
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

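In the transformer, whole decoder layers are now moved with Module.to() instead of rewriting each param.data, eviction is skipped for layer 0 (avoiding the spurious eviction of the last layer at index -1 that the added comment calls out), and forward iterates by index because offloading rebinds entries of self.layers. Below is a self-contained sketch of that layer-offload loop, assuming a CUDA device; TinyBlock and the helper names are hypothetical stand-ins for the real decoder layers and methods above:

```python
import torch
from torch import nn

# Hypothetical stand-in for a decoder layer.
class TinyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

layers = nn.ModuleList([TinyBlock() for _ in range(3)]).to("cpu")
prefetch_stream = torch.cuda.Stream()
device = torch.device("cuda")

def prefetch_layer(idx: int) -> None:
    # Module.to(..., non_blocking=True) moves every parameter and buffer in one
    # call, replacing the per-parameter `param.data = param.data.to(...)` loop.
    with torch.cuda.stream(prefetch_stream):
        layers[idx] = layers[idx].to(device, non_blocking=True)

def evict_previous_layer(idx: int) -> None:
    layers[idx - 1] = layers[idx - 1].to("cpu")  # blocking device-to-CPU move

x = torch.randn(8, 64, device=device)
prefetch_layer(0)
# Iterate by index: offloading rebinds entries of `layers`, so direct indexing
# always picks up the module instance currently resident on the right device.
for layer_idx in range(len(layers)):
    if layer_idx > 0:              # skip the extra eviction of the last layer on step 0
        evict_previous_layer(layer_idx)
    prefetch_stream.synchronize()  # make sure the current layer's weights have arrived
    x = layers[layer_idx](x)
    prefetch_layer((layer_idx + 1) % len(layers))
```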

0 commit comments
