Skip to content

Commit bdf9a00

Browse files
[Fix/Inference] Add unsupported auto-policy error message (#5730)
* [fix] auto policy error message
* trivial
1 parent 283c407 commit bdf9a00

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

colossalai/inference/core/engine.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import time
22
from itertools import count
3-
from typing import Dict, List, Optional, Tuple, Union
3+
from typing import Dict, List, Optional, Tuple, Type, Union
44

55
import numpy as np
66
import torch
@@ -64,7 +64,7 @@ def __init__(
6464
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
6565
inference_config: InferenceConfig,
6666
verbose: bool = False,
67-
model_policy: Policy = None,
67+
model_policy: Union[Policy, Type[Policy]] = None,
6868
) -> None:
6969
self.inference_config = inference_config
7070
self.dtype = inference_config.dtype
@@ -105,7 +105,7 @@ def __init__(
105105

106106
self._verify_args()
107107

108-
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
108+
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
109109
"""
110110
Shard model or/and Load weight
111111
@@ -150,11 +150,17 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
150150
)
151151

152152
if model_policy is None:
153-
if self.inference_config.pad_input:
154-
model_type = "padding_" + self.model_config.model_type
155-
else:
156-
model_type = "nopadding_" + self.model_config.model_type
157-
model_policy = model_policy_map[model_type]()
153+
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
154+
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
155+
model_policy = model_policy_map.get(model_policy_key)
156+
157+
if not isinstance(model_policy, Policy):
158+
try:
159+
model_policy = model_policy()
160+
except Exception as e:
161+
raise ValueError(f"Unable to instantiate model policy: {e}")
162+
163+
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
158164
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
159165
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
160166

0 commit comments

Comments (0)