|
1 | 1 | import time
|
2 | 2 | from itertools import count
|
3 |
| -from typing import Dict, List, Optional, Tuple, Union |
| 3 | +from typing import Dict, List, Optional, Tuple, Type, Union |
4 | 4 |
|
5 | 5 | import numpy as np
|
6 | 6 | import torch
|
@@ -64,7 +64,7 @@ def __init__(
|
64 | 64 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
65 | 65 | inference_config: InferenceConfig,
|
66 | 66 | verbose: bool = False,
|
67 |
| - model_policy: Policy = None, |
| 67 | + model_policy: Union[Policy, Type[Policy]] = None, |
68 | 68 | ) -> None:
|
69 | 69 | self.inference_config = inference_config
|
70 | 70 | self.dtype = inference_config.dtype
|
@@ -105,7 +105,7 @@ def __init__(
|
105 | 105 |
|
106 | 106 | self._verify_args()
|
107 | 107 |
|
108 |
| - def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): |
| 108 | + def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None): |
109 | 109 | """
|
110 | 110 | Shard model or/and Load weight
|
111 | 111 |
|
@@ -150,11 +150,17 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
|
150 | 150 | )
|
151 | 151 |
|
152 | 152 | if model_policy is None:
|
153 |
| - if self.inference_config.pad_input: |
154 |
| - model_type = "padding_" + self.model_config.model_type |
155 |
| - else: |
156 |
| - model_type = "nopadding_" + self.model_config.model_type |
157 |
| - model_policy = model_policy_map[model_type]() |
| 153 | + prefix = "nopadding" if not self.inference_config.pad_input else "padding" |
| 154 | + model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" |
| 155 | + model_policy = model_policy_map.get(model_policy_key) |
| 156 | + |
| 157 | + if not isinstance(model_policy, Policy): |
| 158 | + try: |
| 159 | + model_policy = model_policy() |
| 160 | + except Exception as e: |
| 161 | + raise ValueError(f"Unable to instantiate model policy: {e}") |
| 162 | + |
| 163 | + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" |
158 | 164 | pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
159 | 165 | tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
160 | 166 |
|
|
0 commit comments