diff --git a/OmniGen/scheduler.py b/OmniGen/scheduler.py
index ffa99cd..28001e9 100644
--- a/OmniGen/scheduler.py
+++ b/OmniGen/scheduler.py
@@ -37,8 +37,8 @@ def evict_previous_layer(self, layer_idx: int):
                 prev_layer_idx = -1
             else:
                 prev_layer_idx = (layer_idx - 1) % len(self)
-            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
-            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu")
+            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu")
 
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
@@ -49,9 +49,9 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
                 torch.cuda.current_stream().synchronize()
                 self.evict_previous_layer(layer_idx)
                 # Load current layer cache to its original device if not already there
-                original_device = self.original_device[layer_idx]
+                #original_device = self.original_device[layer_idx]
                 # self.prefetch_stream.synchronize(original_device)
-                torch.cuda.synchronize(self.prefetch_stream)
+                self.prefetch_stream.synchronize()
                 key_tensor = self.key_cache[layer_idx]
                 value_tensor = self.value_cache[layer_idx]
diff --git a/OmniGen/transformer.py b/OmniGen/transformer.py
index a2672a4..5858289 100644
--- a/OmniGen/transformer.py
+++ b/OmniGen/transformer.py
@@ -40,7 +40,7 @@ def evict_previous_layer(self, layer_idx: int):
         "Moves the previous layer cache to the CPU"
         prev_layer_idx = layer_idx - 1
         for name, param in self.layers[prev_layer_idx].named_parameters():
-            param.data = param.data.to("cpu", non_blocking=True)
+            param.data = param.data.to("cpu")
 
     def get_offlaod_layer(self, layer_idx: int, device: torch.device):
         # init stream
@@ -48,11 +48,14 @@ def get_offlaod_layer(self, layer_idx: int, device: torch.device):
             self.prefetch_stream = torch.cuda.Stream()
 
         # delete previous layer
-        torch.cuda.current_stream().synchronize()
-        self.evict_previous_layer(layer_idx)
+        # main stream sync shouldn't be necessary since all computation on iter i-1 is finished by iter i
+        # torch.cuda.current_stream().synchronize()
+        # avoid extra eviction of last layer
+        if layer_idx > 0:
+            self.evict_previous_layer(layer_idx)
 
         # make sure the current layer is ready
-        torch.cuda.synchronize(self.prefetch_stream)
+        self.prefetch_stream.synchronize()
 
         # load next layer
         self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
@@ -133,10 +136,9 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        layer_idx = -1
-        for decoder_layer in self.layers:
-            layer_idx += 1
-
+        for layer_idx in range(len(self.layers)):
+            # direct indexing since offloading may mutate self.layers during iteration
+            decoder_layer = self.layers[layer_idx]
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
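
Note on the synchronization changes above: torch.cuda.synchronize() accepts a device, not a stream, so waiting on the prefetch stream has to go through Stream.synchronize(), and dropping non_blocking=True on the GPU-to-CPU copies makes eviction block until the data has actually landed in host memory. The standalone sketch below is illustrative only (prefetch, evict, and the toy layers list are not part of the patch); it shows the prefetch-on-a-side-stream / evict-on-the-default-stream pattern the patched code follows.

import torch

device = torch.device("cuda")
prefetch_stream = torch.cuda.Stream()
# Toy stand-in for per-layer weights / KV-cache entries (pinned for async H2D copies).
layers = [torch.randn(1024, 1024, pin_memory=True) for _ in range(4)]

def prefetch(idx: int) -> None:
    # Issue the host-to-device copy on the side stream so it overlaps with
    # compute running on the default stream.
    with torch.cuda.stream(prefetch_stream):
        layers[idx] = layers[idx].to(device, non_blocking=True)

def evict(idx: int) -> None:
    # Blocking device-to-host copy: with non_blocking=True the call can return
    # before the copy finishes, which is the hazard the patch removes.
    layers[idx] = layers[idx].to("cpu")

prefetch(0)
for i in range(len(layers)):
    if i > 0:
        evict(i - 1)                   # skip the i == 0 case, as in the patch
    prefetch_stream.synchronize()      # wait for layer i; a stream, not a device, is synchronized
    out = layers[i] @ torch.randn(1024, 1, device=device)  # placeholder compute
    prefetch((i + 1) % len(layers))    # start moving the next layer in the background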