Skip to content

Commit 3e024e9

Browse files
fix (#1383)
1 parent d5f0a4d commit 3e024e9

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

swift/llm/rlhf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
6666
assert len(args.device_max_memory) == n_gpu / local_world_size
6767
model_kwargs['max_memory'] = {
6868
i: mem
69-
for i, mem in zip(list(range(local_rank, n_gpu, local_world_size)), args.device_max_memory)
69+
for i, mem in zip(list(range(max(local_rank, 0), n_gpu, local_world_size)), args.device_max_memory)
7070
}
7171

7272
# quantization

swift/llm/sft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
6565
assert len(args.device_max_memory) == n_gpu / local_world_size
6666
model_kwargs['max_memory'] = {
6767
i: mem
68-
for i, mem in zip(list(range(local_rank, n_gpu, local_world_size)), args.device_max_memory)
68+
for i, mem in zip(list(range(max(local_rank, 0), n_gpu, local_world_size)), args.device_max_memory)
6969
}
7070

7171
if args.quant_method == 'hqq':

0 commit comments

Comments
 (0)