File tree Expand file tree Collapse file tree 1 file changed +0
-19
lines changed
Expand file tree Collapse file tree 1 file changed +0
-19
lines changed Original file line number Diff line number Diff line change 77from utils import cleanup_parallel_strategy , fp32_allclose
88
99
def all_gather_vlen(tensor: torch.Tensor, group=None, dim: int = 0) -> torch.Tensor:
    """Gather tensors of equal rank but possibly different lengths across ranks.

    Standard ``dist.all_gather`` requires identical shapes on every rank; this
    helper first exchanges shapes so each rank can allocate correctly sized
    receive buffers, then moves the data with ``all_to_all``.

    Credit: https://stackoverflow.com/a/78934638

    Args:
        tensor: Local tensor to contribute. All ranks must pass tensors with
            the same number of dimensions and the same dtype/device.
        group: Optional process group; defaults to the global group.
        dim: Dimension along which the gathered pieces are concatenated.

    Returns:
        A single tensor — the per-rank tensors concatenated along ``dim``
        (fixed from the original ``list[torch.Tensor]`` annotation, which did
        not match the ``torch.cat`` return value).
    """
    world_size = dist.get_world_size(group=group)
    # Exchange shapes first: every rank needs to know how large a buffer to
    # allocate for every other rank's payload.
    shape = torch.as_tensor(tensor.shape, device=tensor.device)
    shapes = [torch.empty_like(shape) for _ in range(world_size)]
    dist.all_gather(shapes, shape, group=group)
    # Send a copy of the local tensor to every rank; receive each peer's
    # tensor into a buffer sized from its gathered shape.
    inputs = [tensor] * world_size
    outputs = [
        torch.empty(*_shape, dtype=tensor.dtype, device=tensor.device)
        for _shape in shapes
    ]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=dim)
28-
2910
3011@pytest .fixture (scope = "module" )
3112def parallel_strategy (device : torch .device ):
You can’t perform that action at this time.
0 commit comments