1616sys .path .append (str (parent_dir ))
1717from transfer_queue .data_system import AsyncTransferQueueClient , TransferQueueController , \
1818 TransferQueueStorageSimpleUnit , process_zmq_server_info
19+ from transfer_queue .utils .utils import get_placement_group
1920
2021logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' )
2122logger = logging .getLogger (__name__ )
2223
2324ray .init (runtime_env = {"env_vars" : {"RAY_DEBUG" : "1" , "RAY_DEDUP_LOGS" : "0" }})
2425
25- """
26- 同步的fit函数
27-
28- """
29-
3026
3127def compute_old_log_prob (data1 , data2 ):
3228 time .sleep (3 )
@@ -151,9 +147,13 @@ def _initialize_data_system(self):
151147 # 1. 初始化TransferQueueStorage
152148 total_storage_size = (self .config .global_batch_size * self .config .num_global_batch )
153149 self .data_system_storage_units = {}
150+ storage_placement_group = get_placement_group (self .config .num_data_storage_units , num_cpus_per_actor = 1 )
154151 for storage_unit_rank in range (self .config .num_data_storage_units ):
155152 # TransferQueueStorage通过Ray拉起,是一个ray.remote修饰的类
156- storage_node = TransferQueueStorageSimpleUnit .remote (
153+ storage_node = TransferQueueStorageSimpleUnit .options (
154+ placement_group = storage_placement_group ,
155+ placement_group_bundle_index = storage_unit_rank
156+ ).remote (
157157 storage_size = math .ceil (total_storage_size / self .config .num_data_storage_units )
158158 )
159159 self .data_system_storage_units [storage_unit_rank ] = storage_node
@@ -162,8 +162,12 @@ def _initialize_data_system(self):
162162 # 2. 初始化TransferQueueController
163163 # 这里支持多controller实例以实现负载均衡,支持大规模扩展。不同controller可分配至不同RL计算任务
164164 self .data_system_controllers = {}
165+ controller_placement_group = get_placement_group (self .config .num_data_controllers , num_cpus_per_actor = 1 )
165166 for controller_rank in range (self .config .num_data_controllers ):
166- self .data_system_controllers [controller_rank ] = TransferQueueController .remote (
167+ self .data_system_controllers [controller_rank ] = TransferQueueController .options (
168+ placement_group = controller_placement_group ,
169+ placement_group_bundle_index = controller_rank
170+ ).remote (
167171 num_storage_units = self .config .num_data_storage_units ,
168172 global_batch_size = self .config .global_batch_size ,
169173 num_global_batch = self .config .num_global_batch ,
@@ -193,17 +197,16 @@ def fit(self):
193197 for epoch in range (1 ):
194198 train_dataloader = 1
195199 for step in range (train_dataloader ):
196- input_ids = (torch .tensor ([[1 , 2 ], [3 , 4 ], [5 , 6 ], [7 , 8 ]])) * (step + 1 )
197- prompt_batch = TensorDict ({"input_ids" : input_ids }, batch_size = input_ids .size (0 ))
200+ input_ids = (torch .tensor ([[1 , 2 ], [3 , 4 ], [5 , 6 ], [7 , 8 ], [ 10 , 11 ], [ 100 , 111 ] ])) * (step + 1 )
201+ prompt_batch = TensorDict ({"input_ids" : input_ids , "attention_mask" : input_ids }, batch_size = input_ids .size (0 ))
198202
199- asyncio .run (self .data_system_client .async_put (data = prompt_batch , data_fields = ["input_ids" ],
200- global_step = step ))
203+ asyncio .run (self .data_system_client .async_put (data = prompt_batch , global_step = step ))
201204
202205 logger .info ("demo put prompts ok! " )
203206 time .sleep (5 )
204207
205208 prompt_meta = asyncio .run (self .data_system_client .async_get_meta (
206- data_fields = ['input_ids' ],
209+ data_fields = ['input_ids' , 'attention_mask' ],
207210 batch_size = self .config .global_batch_size ,
208211 global_step = step ,
209212 get_n_samples = False ,
@@ -218,7 +221,7 @@ def fit(self):
218221 self .async_rollout_manager .generate_sequences (prompt_meta )
219222
220223 log_prob_meta = asyncio .run (self .data_system_client .async_get_meta (
221- data_fields = ['input_ids' , 'generate_sequences_ids' ],
 221224 data_fields = ['input_ids' , 'attention_mask' , 'generate_sequences_ids' ],
222225 batch_size = self .config .global_batch_size ,
223226 global_step = step ,
224227 get_n_samples = False ,
@@ -238,7 +241,7 @@ def fit(self):
238241
239242if __name__ == "__main__" :
240243 config_str = """
241- global_batch_size: 4
244+ global_batch_size: 6
242245 num_global_batch: 1
243246 num_data_storage_units: 2
244247 num_data_controllers: 1
0 commit comments