11# Copyright (c) OpenMMLab. All rights reserved.
2+ from typing import List , Optional , Union
3+
24import torch
5+ from torch import Tensor
36from torch .nn .parallel ._functions import _get_stream
47
58
6- def scatter (input , devices , streams = None ):
9+ def scatter (input : Union [List , Tensor ],
10+ devices : List ,
11+ streams : Optional [List ] = None ) -> Union [List , Tensor ]:
712 """Scatters tensor across multiple GPUs."""
813 if streams is None :
914 streams = [None ] * len (devices )
@@ -15,7 +20,7 @@ def scatter(input, devices, streams=None):
1520 [streams [i // chunk_size ]]) for i in range (len (input ))
1621 ]
1722 return outputs
18- elif isinstance (input , torch . Tensor ):
23+ elif isinstance (input , Tensor ):
1924 output = input .contiguous ()
2025 # TODO: copy to a pinned buffer first (if copying from CPU)
2126 stream = streams [0 ] if output .numel () > 0 else None
@@ -28,14 +33,15 @@ def scatter(input, devices, streams=None):
2833 raise Exception (f'Unknown type { type (input )} .' )
2934
3035
31- def synchronize_stream (output , devices , streams ):
36+ def synchronize_stream (output : Union [List , Tensor ], devices : List ,
37+ streams : List ) -> None :
3238 if isinstance (output , list ):
3339 chunk_size = len (output ) // len (devices )
3440 for i in range (len (devices )):
3541 for j in range (chunk_size ):
3642 synchronize_stream (output [i * chunk_size + j ], [devices [i ]],
3743 [streams [i ]])
38- elif isinstance (output , torch . Tensor ):
44+ elif isinstance (output , Tensor ):
3945 if output .numel () != 0 :
4046 with torch .cuda .device (devices [0 ]):
4147 main_stream = torch .cuda .current_stream ()
@@ -45,14 +51,14 @@ def synchronize_stream(output, devices, streams):
4551 raise Exception (f'Unknown type { type (output )} .' )
4652
4753
def get_input_device(input: Union[List, Tensor]) -> int:
    """Return the CUDA device index of the first CUDA tensor in ``input``.

    Args:
        input (list | Tensor): A tensor, or a (possibly nested) list whose
            elements are tensors or further lists.

    Returns:
        int: ``get_device()`` of the first CUDA tensor found when walking
        lists in order and depth-first, or ``-1`` when no CUDA tensor is
        found (this includes an empty list and a non-CUDA tensor).

    Raises:
        Exception: If ``input``, or an element visited during the walk, is
            neither a list nor a Tensor.
    """
    if isinstance(input, list):
        for item in input:
            input_device = get_input_device(item)
            # -1 means "no CUDA tensor in this item"; keep looking.
            if input_device != -1:
                return input_device
        return -1
    elif isinstance(input, Tensor):
        return input.get_device() if input.is_cuda else -1
    else:
        raise Exception(f'Unknown type {type(input)}.')
@@ -61,7 +67,7 @@ def get_input_device(input):
6167class Scatter :
6268
6369 @staticmethod
64- def forward (target_gpus , input ) :
70+ def forward (target_gpus : List [ int ] , input : Union [ List , Tensor ]) -> tuple :
6571 input_device = get_input_device (input )
6672 streams = None
6773 if input_device == - 1 and target_gpus != [- 1 ]:
0 commit comments