Skip to content

Commit 97ecb77

Browse files
authored
Merge pull request #20 from TransferQueue/dev
Merge Dev 0921
2 parents 84daa96 + 70c1473 commit 97ecb77

File tree

22 files changed

+446
-2889
lines changed

22 files changed

+446
-2889
lines changed

recipe/simple_use_case/async_demo.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@
1616
sys.path.append(str(parent_dir))
1717
from transfer_queue.data_system import AsyncTransferQueueClient, TransferQueueController, \
1818
TransferQueueStorageSimpleUnit, process_zmq_server_info
19+
from transfer_queue.utils.utils import get_placement_group
1920

2021
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
2122
logger = logging.getLogger(__name__)
2223

2324
ray.init(runtime_env={"env_vars": {"RAY_DEBUG": "1", "RAY_DEDUP_LOGS": "0"}})
2425

25-
"""
26-
同步的fit函数
27-
28-
"""
29-
3026

3127
def compute_old_log_prob(data1, data2):
3228
time.sleep(3)
@@ -151,9 +147,13 @@ def _initialize_data_system(self):
151147
# 1. 初始化TransferQueueStorage
152148
total_storage_size = (self.config.global_batch_size * self.config.num_global_batch)
153149
self.data_system_storage_units = {}
150+
storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1)
154151
for storage_unit_rank in range(self.config.num_data_storage_units):
155152
# TransferQueueStorage通过Ray拉起,是一个ray.remote修饰的类
156-
storage_node = TransferQueueStorageSimpleUnit.remote(
153+
storage_node = TransferQueueStorageSimpleUnit.options(
154+
placement_group=storage_placement_group,
155+
placement_group_bundle_index=storage_unit_rank
156+
).remote(
157157
storage_size=math.ceil(total_storage_size / self.config.num_data_storage_units)
158158
)
159159
self.data_system_storage_units[storage_unit_rank] = storage_node
@@ -162,8 +162,12 @@ def _initialize_data_system(self):
162162
# 2. 初始化TransferQueueController
163163
# 这里支持多controller实例以实现负载均衡,支持大规模扩展。不同controller可分配至不同RL计算任务
164164
self.data_system_controllers = {}
165+
controller_placement_group = get_placement_group(self.config.num_data_controllers, num_cpus_per_actor=1)
165166
for controller_rank in range(self.config.num_data_controllers):
166-
self.data_system_controllers[controller_rank] = TransferQueueController.remote(
167+
self.data_system_controllers[controller_rank] = TransferQueueController.options(
168+
placement_group=controller_placement_group,
169+
placement_group_bundle_index=controller_rank
170+
).remote(
167171
num_storage_units=self.config.num_data_storage_units,
168172
global_batch_size=self.config.global_batch_size,
169173
num_global_batch=self.config.num_global_batch,
@@ -193,17 +197,16 @@ def fit(self):
193197
for epoch in range(1):
194198
train_dataloader = 1
195199
for step in range(train_dataloader):
196-
input_ids = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])) * (step + 1)
197-
prompt_batch = TensorDict({"input_ids": input_ids}, batch_size=input_ids.size(0))
200+
input_ids = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]])) * (step + 1)
201+
prompt_batch = TensorDict({"input_ids": input_ids, "attention_mask": input_ids}, batch_size=input_ids.size(0))
198202

199-
asyncio.run(self.data_system_client.async_put(data=prompt_batch, data_fields=["input_ids"],
200-
global_step=step))
203+
asyncio.run(self.data_system_client.async_put(data=prompt_batch, global_step=step))
201204

202205
logger.info("demo put prompts ok! ")
203206
time.sleep(5)
204207

205208
prompt_meta = asyncio.run(self.data_system_client.async_get_meta(
206-
data_fields=['input_ids'],
209+
data_fields=['input_ids', 'attention_mask'],
207210
batch_size=self.config.global_batch_size,
208211
global_step=step,
209212
get_n_samples=False,
@@ -218,7 +221,7 @@ def fit(self):
218221
self.async_rollout_manager.generate_sequences(prompt_meta)
219222

220223
log_prob_meta = asyncio.run(self.data_system_client.async_get_meta(
221-
data_fields=['input_ids', 'generate_sequences_ids'],
224+
data_fields=['input_ids', 'attention_mask', 'generate_sequences_ids'],
222225
batch_size=self.config.global_batch_size,
223226
global_step=step,
224227
get_n_samples=False,
@@ -238,7 +241,7 @@ def fit(self):
238241

239242
if __name__ == "__main__":
240243
config_str = """
241-
global_batch_size: 4
244+
global_batch_size: 6
242245
num_global_batch: 1
243246
num_data_storage_units: 2
244247
num_data_controllers: 1

recipe/simple_use_case/sync_demo.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
parent_dir = Path(__file__).resolve().parent.parent.parent
1414
sys.path.append(str(parent_dir))
1515
from transfer_queue.data_system import TransferQueueController, TransferQueueStorageSimpleUnit, process_zmq_server_info
16+
from transfer_queue.utils.utils import get_placement_group
1617

1718

1819
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -25,19 +26,27 @@ def initialize_data_system(config):
2526
# 1. 初始化TransferQueueStorage
2627
total_storage_size = (config.global_batch_size * config.num_global_batch)
2728
data_system_storage_units = {}
29+
storage_placement_group = get_placement_group(config.num_data_storage_units, num_cpus_per_actor=1)
2830
for storage_unit_rank in range(config.num_data_storage_units):
2931
# TransferQueueStorage通过Ray拉起,是一个ray.remote修饰的类
30-
storage_node = TransferQueueStorageSimpleUnit.remote(
32+
storage_node = TransferQueueStorageSimpleUnit.options(
33+
placement_group=storage_placement_group,
34+
placement_group_bundle_index=storage_unit_rank
35+
).remote(
3136
storage_size=math.ceil(total_storage_size / config.num_data_storage_units)
3237
)
3338
data_system_storage_units[storage_unit_rank] = storage_node
3439
logger.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
3540

3641
# 2. 初始化TransferQueueController
3742
# 这里支持多controller实例以实现负载均衡,支持大规模扩展。不同controller可分配至不同RL计算任务
43+
controller_placement_group = get_placement_group(config.num_data_controllers, num_cpus_per_actor=1)
3844
data_system_controllers = {}
3945
for controller_rank in range(config.num_data_controllers):
40-
data_system_controllers[controller_rank] = TransferQueueController.remote(
46+
data_system_controllers[controller_rank] = TransferQueueController.options(
47+
placement_group=controller_placement_group,
48+
placement_group_bundle_index=controller_rank
49+
).remote(
4150
num_storage_units=config.num_data_storage_units,
4251
global_batch_size=config.global_batch_size,
4352
num_global_batch=config.num_global_batch,
@@ -115,15 +124,15 @@ def fit(config, data_system_client):
115124
for epoch in range(2):
116125
train_dataloader = 2
117126
for step in range(train_dataloader):
118-
input_ids = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])) * (step + 1)
119-
prompt_batch = TensorDict({"input_ids": input_ids}, batch_size=input_ids.size(0))
127+
input_ids = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]])) * (step + 1)
128+
prompt_batch = TensorDict({"input_ids": input_ids, "attention_mask": input_ids}, batch_size=input_ids.size(0))
120129

121-
data_system_client.put(data=prompt_batch, data_fields=["input_ids"], global_step=step)
130+
data_system_client.put(data=prompt_batch, global_step=step)
122131
logger.info("demo put prompts ok! ")
123132
time.sleep(5)
124133

125134
prompt_meta = data_system_client.get_meta(
126-
data_fields=['input_ids'],
135+
data_fields=['input_ids', 'attention_mask'],
127136
batch_size=config.global_batch_size,
128137
global_step=step,
129138
get_n_samples=False,
@@ -136,7 +145,7 @@ def fit(config, data_system_client):
136145
actor_rollout_wg_generate_sequences(prompt_meta, data_system_client)
137146

138147
log_prob_meta = data_system_client.get_meta(
139-
data_fields=['input_ids', 'generate_sequences_ids'],
148+
data_fields=['input_ids', 'attention_mask', 'generate_sequences_ids'],
140149
batch_size=config.global_batch_size,
141150
global_step=0,
142151
get_n_samples=False,
@@ -166,7 +175,7 @@ def main(config):
166175

167176
if __name__ == "__main__":
168177
config_str = """
169-
global_batch_size: 4
178+
global_batch_size: 6
170179
num_global_batch: 1
171180
num_data_storage_units: 2
172181
num_data_controllers: 1

recipe/verl_use_case/main.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

0 commit comments

Comments
 (0)