Skip to content

Commit 69e3c69

Browse files
authored
fix max_memory (#3347)
1 parent 6aae524 commit 69e3c69

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

swift/llm/argument/base_args/model_args.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import ast
23
import math
34
import os
45
from dataclasses import dataclass, field
@@ -80,6 +81,11 @@ def _init_device_map(self):
8081
self.device_map[k] += local_rank
8182

8283
def _init_max_memory(self):
84+
if isinstance(self.max_memory, str):
85+
try:
86+
self.max_memory = ast.literal_eval(self.max_memory)
87+
except Exception:
88+
pass
8389
self.max_memory = self.parse_to_dict(self.max_memory)
8490
# compat mp&ddp
8591
_, local_rank, _, local_world_size = get_dist_setting()

swift/llm/model/register.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -486,7 +486,7 @@ def get_model_tokenizer(
486486
# model kwargs
487487
model_type: Optional[str] = None,
488488
quantization_config=None,
489-
max_memory: Optional[List[str]] = None,
489+
max_memory: Union[str, Dict[str, Any]] = None,
490490
attn_impl: Literal['flash_attn', 'sdpa', 'eager', None] = None,
491491
rope_scaling: Optional[Dict[str, Any]] = None,
492492
automodel_class=None,

0 commit comments

Comments
 (0)