"""
This file contains primitives for multi-GPU communication.
This is useful when doing distributed training.

Largely borrowed from maskrcnn-benchmark and ST3D.
"""

import pickle
import time

import torch
import torch.distributed as dist


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
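

# Usage sketch (an illustration, not part of the original utilities): a common
# pattern is to let only the main process write artifacts to disk and then
# barrier so the other ranks do not race ahead. Assumes torch.distributed has
# already been initialized; the helper below and its arguments are hypothetical.
def _example_checkpoint_then_synchronize(model, path):
    if is_main_process():
        torch.save(model.state_dict(), path)
    synchronize()  # all ranks wait until the checkpoint is on disk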


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize non-tensor data to a byte tensor
    origin_size = None
    if not isinstance(data, torch.Tensor):
        buffer = pickle.dumps(data)
        storage = torch.ByteStorage.from_buffer(buffer)
        tensor = torch.ByteTensor(storage).to("cuda")
    else:
        origin_size = data.size()
        tensor = data.reshape(-1)

    tensor_type = tensor.dtype

    # obtain the tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receive tensors from all ranks;
    # we pad the local tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.FloatTensor(size=(max_size,)).cuda().to(tensor_type))
    if local_size != max_size:
        padding = torch.FloatTensor(size=(max_size - local_size,)).cuda().to(tensor_type)
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        if origin_size is None:
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        else:
            buffer = tensor[:size]
            data_list.append(buffer)

    if origin_size is not None:
        new_shape = [-1] + list(origin_size[1:])
        resized_list = []
        for data in data_list:
            # assume tensors from different ranks differ only in the first dimension
            data = data.reshape(new_shape)
            resized_list.append(data)

        return resized_list
    else:
        return data_list
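

# Usage sketch (an illustration, not part of the original utilities): gather
# arbitrary picklable Python objects, e.g. per-rank evaluation results, onto
# every rank. Assumes torch.distributed is initialized and each rank has a
# CUDA device available; the helper below is hypothetical.
def _example_all_gather_results():
    local_result = {"rank": get_rank(), "num_samples": 128}  # any picklable object
    results = all_gather(local_result)  # list with one entry per rank
    if is_main_process():
        total = sum(r["num_samples"] for r in results)
        print(f"gathered {len(results)} results, {total} samples in total")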


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that the process
    with rank 0 has the reduced results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only the main process accumulates the results, so only divide by
            # world_size on rank 0
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
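

# Usage sketch (an illustration, not part of the original utilities): average a
# dict of scalar loss tensors across ranks so that rank 0 can log the global
# mean. Assumes torch.distributed is initialized and the tensors live on each
# rank's GPU; the helper below is hypothetical.
def _example_log_reduced_losses(loss_dict):
    # loss_dict: e.g. {"cls_loss": tensor(0.7), "reg_loss": tensor(1.2)} on the GPU
    reduced = reduce_dict(loss_dict, average=True)
    if is_main_process():
        print({k: v.item() for k, v in reduced.items()})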


def average_reduce_value(data):
    data_list = all_gather(data)
    return sum(data_list) / len(data_list)


def all_reduce(data, op="sum", average=False):
    """
    All-reduce a tensor across all ranks with the given reduction op
    ('sum', 'max', 'min' or 'product'); when average is True, divide the
    summed result by the world size (only valid with op='sum').
    """
    def op_map(op):
        op_dict = {
            "SUM": dist.ReduceOp.SUM,
            "MAX": dist.ReduceOp.MAX,
            "MIN": dist.ReduceOp.MIN,
            "PRODUCT": dist.ReduceOp.PRODUCT,
        }
        return op_dict[op]

    world_size = get_world_size()
    if world_size > 1:
        reduced_data = data.clone()
        dist.all_reduce(reduced_data, op=op_map(op.upper()))
        if average:
            assert op.upper() == 'SUM'
            return reduced_data / world_size
        else:
            return reduced_data
    return data
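

# Usage sketch (an illustration, not part of the original utilities): average a
# scalar metric tensor, e.g. a per-rank accuracy, so every rank sees the mean.
# Assumes torch.distributed is initialized and the tensor lives on the GPU;
# the helper below is hypothetical.
def _example_average_metric(metric_tensor):
    # metric_tensor: e.g. torch.tensor([0.83], device="cuda")
    return all_reduce(metric_tensor, op="sum", average=True)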


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
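

# Usage sketch (an illustration, not part of the original utilities): gather
# feature batches from all ranks into one tensor. Unlike all_gather above,
# every rank must contribute a tensor of identical shape, and no gradients flow
# back through the gathered copies. Assumes torch.distributed is initialized;
# the helper below is hypothetical.
def _example_gather_features(features):
    # features: (batch_size, feat_dim) CUDA tensor, same shape on every rank
    return concat_all_gather(features)  # (world_size * batch_size, feat_dim)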