- 
                Notifications
    
You must be signed in to change notification settings  - Fork 6.5k
 
          [feat] implement record_stream when using CUDA streams during group offloading
          #11081
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
ffce2d1
              f25ea18
              2a28f6d
              41ea4c8
              f5b69b0
              9281e84
              637f84e
              612136f
              d5afea5
              fb59f36
              4a6eeba
              87a93fe
              1d4ca61
              535dcd1
              2ff9112
              b4deedc
              622aba7
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -56,6 +56,7 @@ def __init__( | |
| buffers: Optional[List[torch.Tensor]] = None, | ||
| non_blocking: bool = False, | ||
| stream: Optional[torch.cuda.Stream] = None, | ||
| record_stream: Optional[bool] = False, | ||
| cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, | ||
| onload_self: bool = True, | ||
| ) -> None: | ||
| 
        
          
        
         | 
    @@ -68,33 +69,47 @@ def __init__( | |
| self.buffers = buffers | ||
| self.non_blocking = non_blocking or stream is not None | ||
| self.stream = stream | ||
| self.record_stream = record_stream | ||
| self.cpu_param_dict = cpu_param_dict | ||
| self.onload_self = onload_self | ||
| 
     | 
||
| if self.stream is not None and self.cpu_param_dict is None: | ||
| raise ValueError("cpu_param_dict must be provided when using stream for data transfer.") | ||
| raise ValueError("`cpu_param_dict` must be provided when using stream for data transfer.") | ||
| 
     | 
||
| if self.record_stream and not self.stream: | ||
| raise ValueError("`record_stream` cannot be True when `stream` is None.") | ||
| 
     | 
||
def onload_(self):
    r"""Moves the group's modules, standalone parameters and buffers onto ``self.onload_device``.

    When a CUDA stream is configured, the transfers are issued on that stream. If
    ``record_stream`` is enabled, every onloaded tensor is additionally recorded against
    the stream that was current on entry (via ``Tensor.record_stream``), so the caching
    allocator knows the tensor is still in use by that consumer stream.
    """
    use_stream = self.stream is not None
    transfer_context = torch.cuda.stream(self.stream) if use_stream else nullcontext()
    # The stream the onloaded tensors will be consumed on, captured before we switch
    # the current stream to the transfer stream below.
    consumer_stream = torch.cuda.current_stream() if self.record_stream else None

    if use_stream:
        # Wait for previous Host->Device transfer to complete
        self.stream.synchronize()

    with transfer_context:
        for module in self.modules:
            module.to(self.onload_device, non_blocking=self.non_blocking)
            if self.record_stream:
                for p in module.parameters():
                    p.data.record_stream(consumer_stream)

        # `parameters` / `buffers` may be None; treat that the same as empty.
        for param in self.parameters or []:
            param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
            if self.record_stream:
                param.data.record_stream(consumer_stream)

        for buf in self.buffers or []:
            buf.data = buf.data.to(self.onload_device, non_blocking=self.non_blocking)
            if self.record_stream:
                buf.data.record_stream(consumer_stream)
| 
     | 
||
| def offload_(self): | ||
| r"""Offloads the group of modules to the offload_device.""" | ||
| if self.stream is not None: | ||
| torch.cuda.current_stream().synchronize() | ||
| if not self.record_stream: | ||
| torch.cuda.current_stream().synchronize() | ||
| for group_module in self.modules: | ||
| for param in group_module.parameters(): | ||
| param.data = self.cpu_param_dict[param] | ||
| 
          
            
          
           | 
    @@ -268,6 +283,7 @@ def apply_group_offloading( | |
| num_blocks_per_group: Optional[int] = None, | ||
| non_blocking: bool = False, | ||
| use_stream: bool = False, | ||
| record_stream: bool = False, | ||
| ) -> None: | ||
| r""" | ||
| Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and | ||
| 
          
            
          
           | 
    @@ -314,6 +330,7 @@ def apply_group_offloading( | |
| use_stream (`bool`, defaults to `False`): | ||
| If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for | ||
| overlapping computation and data transfer. | ||
| record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current CUDA stream as the consumer of the transferred tensors via `Tensor.record_stream`, which can avoid an explicit stream synchronization during offloading. | ||
| 
     | 
||
| Example: | ||
| ```python | ||
| 
          
            
          
           | 
    @@ -349,10 +366,10 @@ def apply_group_offloading( | |
| raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") | ||
| 
     | 
||
| _apply_group_offloading_block_level( | ||
| module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream | ||
| module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, record_stream | ||
| ) | ||
| elif offload_type == "leaf_level": | ||
| _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream) | ||
| _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, record_stream) | ||
| else: | ||
| raise ValueError(f"Unsupported offload_type: {offload_type}") | ||
| 
     | 
||
| 
        
          
        
         | 
    @@ -364,6 +381,7 @@ def _apply_group_offloading_block_level( | |
| onload_device: torch.device, | ||
| non_blocking: bool, | ||
| stream: Optional[torch.cuda.Stream] = None, | ||
| record_stream: Optional[bool] = False, | ||
| ) -> None: | ||
| r""" | ||
| This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to | ||
| 
        
          
        
         | 
    @@ -382,6 +400,7 @@ def _apply_group_offloading_block_level( | |
| stream (`torch.cuda.Stream`, *optional*): | ||
| If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful | ||
| for overlapping computation and data transfer. | ||
| record_stream (`bool`, *optional*, defaults to `False`): If True and a stream is provided, onloaded tensors are recorded on the current stream with `Tensor.record_stream` so the offload path can skip the explicit stream synchronization. | ||
| """ | ||
| 
     | 
||
| # Create a pinned CPU parameter dict for async data transfer if streams are to be used | ||
| 
          
            
          
           | 
    @@ -411,6 +430,7 @@ def _apply_group_offloading_block_level( | |
| onload_leader=current_modules[0], | ||
| non_blocking=non_blocking, | ||
| stream=stream, | ||
| record_stream=record_stream, | ||
| cpu_param_dict=cpu_param_dict, | ||
| onload_self=stream is None, | ||
| ) | ||
| 
          
            
          
           | 
    @@ -448,6 +468,7 @@ def _apply_group_offloading_block_level( | |
| buffers=buffers, | ||
| non_blocking=False, | ||
| stream=None, | ||
| record_stream=False, | ||
| cpu_param_dict=None, | ||
| onload_self=True, | ||
| ) | ||
| 
        
          
        
         | 
    @@ -461,6 +482,7 @@ def _apply_group_offloading_leaf_level( | |
| onload_device: torch.device, | ||
| non_blocking: bool, | ||
| stream: Optional[torch.cuda.Stream] = None, | ||
| record_stream: Optional[bool] = False, | ||
| ) -> None: | ||
| r""" | ||
| This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory | ||
| 
        
          
        
         | 
    @@ -481,6 +503,7 @@ def _apply_group_offloading_leaf_level( | |
| stream (`torch.cuda.Stream`, *optional*): | ||
| If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful | ||
| for overlapping computation and data transfer. | ||
| record_stream (`bool`, *optional*, defaults to `False`): If True and a stream is provided, onloaded tensors are recorded on the current stream with `Tensor.record_stream` so the offload path can skip the explicit stream synchronization. | ||
| """ | ||
| 
     | 
||
| # Create a pinned CPU parameter dict for async data transfer if streams are to be used | ||
| 
        
          
        
         | 
    @@ -503,6 +526,7 @@ def _apply_group_offloading_leaf_level( | |
| onload_leader=submodule, | ||
| non_blocking=non_blocking, | ||
| stream=stream, | ||
| record_stream=record_stream, | ||
| cpu_param_dict=cpu_param_dict, | ||
| onload_self=True, | ||
| ) | ||
| 
          
            
          
           | 
    @@ -548,6 +572,7 @@ def _apply_group_offloading_leaf_level( | |
| buffers=buffers, | ||
| non_blocking=non_blocking, | ||
| stream=stream, | ||
| record_stream=record_stream, | ||
| cpu_param_dict=cpu_param_dict, | ||
| onload_self=True, | ||
| ) | ||
| 
        
          
        
         | 
    @@ -567,6 +592,7 @@ def _apply_group_offloading_leaf_level( | |
| buffers=None, | ||
| non_blocking=False, | ||
| stream=None, | ||
| record_stream=False, | ||
| cpu_param_dict=None, | ||
| onload_self=True, | ||
| ) | ||
| 
          
            
          
           | 
    ||
Uh oh!
There was an error while loading. Please reload this page.