Skip to content

Commit fcafcb2

Browse files
fixed (#1384)
1 parent 3e024e9 commit fcafcb2

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

swift/llm/rlhf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     if args.device_max_memory:
         n_gpu = torch.cuda.device_count()
-        assert len(args.device_max_memory) == n_gpu / local_world_size
+        assert len(args.device_max_memory) == n_gpu // local_world_size
         model_kwargs['max_memory'] = {
             i: mem
             for i, mem in zip(list(range(max(local_rank, 0), n_gpu, local_world_size)), args.device_max_memory)

swift/llm/sft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
    if args.device_max_memory:
        n_gpu = torch.cuda.device_count()
-       assert len(args.device_max_memory) == n_gpu / local_world_size
+       assert len(args.device_max_memory) == n_gpu // local_world_size
        model_kwargs['max_memory'] = {
            i: mem
            for i, mem in zip(list(range(max(local_rank, 0), n_gpu, local_world_size)), args.device_max_memory)

0 commit comments

Comments
 (0)