
Commit deb3c75

楚财峯回 authored and committed
PullRequest: 998 Add Off-Policy configuration files and adjust related training parameters
Merge branch chucai.dzq/align-offpolicy of [email protected]:inclusionAI/AReaL.git into asystem/gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/998
Reviewed-by: 峯回 <[email protected]>

* align off policy
1 parent 6876b63 commit deb3c75

File tree

10 files changed: +1211 additions, -374 deletions


areal/api/alloc_mode.py

Lines changed: 611 additions & 313 deletions
Large diffs are not rendered by default.

areal/controller/train_controller.py

Lines changed: 4 additions & 3 deletions
@@ -249,19 +249,20 @@ async def _async_custom_function_call(self, method: str, *args, **kwargs):
 
     def _dispatch_inputs(self, *args, **kwargs):
         # Find and split DistributedBatch arguments
+        rebalance = kwargs.pop("rebalance", True)
         split_args = []
         for arg in args:
             if isinstance(arg, DistributedBatch):
                 # Split across DP groups
-                split_args.append(self._align_batches_with_dp(arg, rebalance=True))
+                split_args.append(self._align_batches_with_dp(arg, rebalance=rebalance))
             else:
                 # Replicate to all DP heads
                 split_args.append([arg] * self.parallel_strategy.dp_size)
 
         split_kwargs = {}
         for k, v in kwargs.items():
             if isinstance(v, DistributedBatch):
-                split_kwargs[k] = self._align_batches_with_dp(v, rebalance=True)
+                split_kwargs[k] = self._align_batches_with_dp(v, rebalance=rebalance)
             else:
                 split_kwargs[k] = [v] * self.parallel_strategy.dp_size
         return split_args, split_kwargs
@@ -633,7 +634,7 @@ def train_batch(
         gradient norm, etc.
         """
         return self._custom_function_call(
-            "train_batch", input_, loss_fn, loss_weight_fn
+            "train_batch", input_, loss_fn, loss_weight_fn, rebalance=True
         )
 
     def eval_batch(
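
A note on the `_dispatch_inputs` change above: the new flag is read with `kwargs.pop("rebalance", True)` rather than a plain lookup, because any key left in `kwargs` is treated as payload by the loop below and would be replicated to every DP head. Popping it consumes the flag at the dispatch layer while keeping the previous behavior (rebalance on) as the default. The following is a minimal, self-contained sketch of that pattern, not AReaL's actual implementation: `_MiniDispatcher` and its toy list-based splitter are hypothetical stand-ins for TrainController, `_align_batches_with_dp`, and DistributedBatch.

from typing import Any


class _MiniDispatcher:
    """Hypothetical stand-in for TrainController; illustrates the pattern only."""

    def __init__(self, dp_size: int) -> None:
        self.dp_size = dp_size

    def _align_batches_with_dp(self, batch: list, rebalance: bool) -> list:
        # Toy splitter: chunk the batch evenly across DP ranks. A real
        # implementation would redistribute work (e.g. by sequence length)
        # when rebalance=True; here the flag is only threaded through.
        chunk = max(1, len(batch) // self.dp_size)
        return [batch[i * chunk:(i + 1) * chunk] for i in range(self.dp_size)]

    def _dispatch_inputs(self, *args: Any, **kwargs: Any):
        # Pop the control flag before iterating over kwargs, as the commit
        # does, so it is consumed here rather than replicated to ranks.
        rebalance = kwargs.pop("rebalance", True)
        split_args = [
            self._align_batches_with_dp(a, rebalance=rebalance)
            if isinstance(a, list) else [a] * self.dp_size
            for a in args
        ]
        split_kwargs = {
            k: self._align_batches_with_dp(v, rebalance=rebalance)
            if isinstance(v, list) else [v] * self.dp_size
            for k, v in kwargs.items()
        }
        return split_args, split_kwargs


d = _MiniDispatcher(dp_size=2)
args, kw = d._dispatch_inputs([1, 2, 3, 4], lr=0.1, rebalance=False)
print(args)  # [[[1, 2], [3, 4]]]  -- the batch, split across two ranks
print(kw)    # {'lr': [0.1, 0.1]}  -- scalars replicated; 'rebalance' consumed

With this in place, `train_batch` (per the diff above) opts in explicitly with `rebalance=True`, while other call sites can pass `rebalance=False` without `_dispatch_inputs` mistaking the flag for data.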
