@@ -1,5 +1,6 @@
 from typing import Any, Dict
-
+import copy
+import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed.distributed_c10d as c10d
@@ -30,28 +31,122 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
         obj = c10d._tensor_to_object(obj, size_tensor.item())
     return obj
 
-
 def ray_broadcast_tensor_dict(
-    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
+    tensor_dict: Dict[str, torch.Tensor],
+    src: int = 0,
+    device=None,
+    group_name: str = "default",
+    backend: str = "nccl",
+    offload_to_cpu: bool = False,
+    pin_memory: bool = False,
 ) -> Dict[str, torch.Tensor]:
     rank = cc.get_rank(group_name)
+    if tensor_dict is None:
+        tensor_dict = {}
     if rank == src:
         metadata = []
         for k, v in tensor_dict.items():
             metadata.append((k, v.shape, v.dtype))
     else:
         metadata = None
     metadata = ray_broadcast_object(metadata, src, device, group_name)
-    if rank != src:
-        out_dict = {}
     for k, shape, dtype in metadata:
         if rank == src:
-            tensor = tensor_dict[k]
+            if offload_to_cpu:
+                tensor = tensor_dict[k].to(device)
+            else:
+                tensor = tensor_dict[k]
         else:
-            tensor = torch.empty(shape, dtype=dtype, device=device)
+            tensor = tensor_dict.get(k, torch.zeros(shape, dtype=dtype, device=device, pin_memory=pin_memory))
+            if backend == "gloo" and dtype == torch.bfloat16:
+                # Gloo does not support bfloat16, convert to float16
+                tensor = tensor.view(torch.float16)
         cc.broadcast(tensor, src, group_name)
+        if backend == "gloo" and dtype == torch.bfloat16:
+            # Convert back to bfloat16 if it was converted to float16
+            tensor = tensor.view(torch.bfloat16)
         if rank != src:
-            out_dict[k] = tensor
-    if rank == src:
-        out_dict = tensor_dict
-    return out_dict
+            if offload_to_cpu:
+                tensor_dict[k] = tensor.cpu()
+            else:
+                tensor_dict[k] = tensor
+    return tensor_dict
+
+
+@ray.remote
+class SharedVariableActor:
+    def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000):
+        self.data_queue = []
+        self.data_uid = 0
+        self.number_of_readers = number_of_readers
+        self.queue_size = 0
+        self.signals = {}
+        self.process_locks = {}
+        self.signal_procs_meet_count = {}
+        self.buffer_size_limit = buffer_size_limit
+
+    def pickup_rollout_task(self, num_tasks: int):
+        """
+        Use the queue size to decide whether producers should generate new rollouts or
+        wait for the consumer to drain more data. If the queue size is below the limit,
+        the consumer is keeping up, so producers may generate new rollouts. If it is at
+        or above the limit, the consumer is falling behind, so producers should wait for
+        it to consume more data.
+
+        Any free producer can pick up the task to generate a rollout and then increase
+        queue_size, preventing other producers from picking up the same task redundantly.
+        Note that this is not the real queue length, as the data may still be generating.
+        """
+        ret = False
+        if self.queue_size < self.buffer_size_limit:
+            ret = True
+            self.queue_size += num_tasks
+        return ret
+
+    def append_data(self, data):
+        self.data_queue.append([self.data_uid, data, 0])  # [data_uid, data, access_count]
+        self.data_uid += 1
+        return True
+
+    def get_data(self, data_uid: int):
+        # for multi-process data reading
+        if not self.data_queue:
+            # no data in the queue, return None
+            return None
+        to_pop_index = None
+        ret = None
+        for i, (uid, data, access_count) in enumerate(self.data_queue):
+            if uid == data_uid:
+                # found the data with the given uid
+                self.data_queue[i][2] += 1
+                ret = copy.deepcopy(data)
+                if self.data_queue[i][2] == self.number_of_readers:
+                    to_pop_index = i
+                break
+        if to_pop_index is not None:
+            # remove the data from the queue once it has been accessed by all readers
+            self.data_queue.pop(to_pop_index)
+            self.queue_size -= data["input_ids"].size(0)
+        return ret
+
+    def acquire_process_lock(self, key: str):
+        # atomic lock for process
+        if key not in self.process_locks:
+            self.process_locks[key] = 1  # locked
+            return 0
+        if self.process_locks[key] == 0:
+            self.process_locks[key] = 1  # lock the process
+            return 0
+        else:
+            return 1
+
+    def release_process_lock(self, key: str):
+        # atomic unlock for process
+        assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked."
+        self.process_locks[key] = 0
+
+    def set_signal(self, key: str, signal: str):
+        self.signals[key] = signal
+
+    def get_signal(self):
+        return self.signals
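
A minimal usage sketch of the extended ray_broadcast_tensor_dict follows. Everything outside the function signature is an assumption for illustration: the group name "sync_group", the SyncWorker actor, the two-worker layout, and the availability of ray_broadcast_tensor_dict inside the actors' module are not part of this commit.

import ray
import ray.util.collective as cc
import torch


@ray.remote
class SyncWorker:
    def setup(self, world_size: int, rank: int):
        # gloo runs over CPU tensors, which is why the bfloat16 -> float16
        # view-cast exists in ray_broadcast_tensor_dict above
        cc.init_collective_group(world_size, rank, backend="gloo", group_name="sync_group")

    def sync(self, state_dict=None):
        # rank 0 passes its tensors; the other rank passes None and gets a filled dict back
        return ray_broadcast_tensor_dict(
            state_dict,
            src=0,
            device="cpu",
            group_name="sync_group",
            backend="gloo",
            offload_to_cpu=True,
            pin_memory=False,
        )


workers = [SyncWorker.remote() for _ in range(2)]
ray.get([w.setup.remote(2, i) for i, w in enumerate(workers)])
weights = {"w": torch.randn(4, 4, dtype=torch.bfloat16)}
ray.get([workers[0].sync.remote(weights), workers[1].sync.remote(None)])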
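A rough sketch of the producer/consumer protocol implied by SharedVariableActor's queue_size back-pressure. The batch layout ({"input_ids": ...}) follows get_data's size accounting; the helper names, batch shape, and loop structure are illustrative assumptions rather than part of this diff.

import time

import ray
import torch

shared = SharedVariableActor.remote(number_of_readers=2, buffer_size_limit=64)


def producer_step(batch_size: int = 8) -> bool:
    # only start a new rollout when the queue has room; pickup_rollout_task
    # reserves batch_size slots so other producers do not over-produce
    if not ray.get(shared.pickup_rollout_task.remote(batch_size)):
        return False
    rollout = {"input_ids": torch.zeros(batch_size, 128, dtype=torch.long)}  # placeholder rollout
    ray.get(shared.append_data.remote(rollout))
    return True


def consumer_loop(num_batches: int):
    # each of the number_of_readers consumers reads every uid exactly once; the actor
    # drops an entry (and frees queue_size) once all readers have seen it
    uid = 0
    while uid < num_batches:
        data = ray.get(shared.get_data.remote(uid))
        if data is None:
            time.sleep(0.5)  # uid not produced yet, poll again
            continue
        # ... consume data, e.g. one training step ...
        uid += 1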
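acquire_process_lock returns 0 when the lock is taken and 1 when another process already holds it, so it can back a simple cross-process mutex. A hedged polling wrapper; the helper name and the 0.1 s interval are arbitrary choices, not from this commit.

import time

import ray


def run_with_actor_lock(shared, key: str, fn):
    # spin until acquire_process_lock returns 0 (acquired); 1 means the lock
    # is currently held by another process
    while ray.get(shared.acquire_process_lock.remote(key)) != 0:
        time.sleep(0.1)
    try:
        return fn()
    finally:
        ray.get(shared.release_process_lock.remote(key))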