@@ -41,7 +41,7 @@ |
41 | 41 | ChatGLMv2Tokenizer, |
42 | 42 | Llama3Tokenizer, |
43 | 43 | LlamaTokenizer, |
44 | | - PretrainedModel, |
| 44 | + PretrainedConfig, |
45 | 45 | PretrainedTokenizer, |
46 | 46 | ) |
47 | 47 | from paddlenlp.trl import llm_utils |
@@ -245,11 +245,9 @@ def predict(self, input_texts: str | list[str], return_tokens=False): |
245 | 245 |
246 | 246 |
247 | 247 | class DygraphPredictor(BasePredictor): |
248 | | - def __init__( |
249 | | - self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None |
250 | | - ): |
| 248 | + def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs): |
251 | 249 | super().__init__(config, tokenizer) |
252 | | - self.model = model |
| 250 | + self.model = kwargs.get("model", None) |
253 | 251 | if config.lora_path is not None: |
254 | 252 | lora_config = LoRAConfig.from_pretrained(config.lora_path) |
255 | 253 | dtype = lora_config.dtype |
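Note: every predictor constructor now shares the shape `(config, tokenizer=None, **kwargs)`, with the model passed as a keyword argument rather than positionally. A minimal usage sketch under that assumption, where `cfg`, `tok`, and `mdl` are illustrative placeholders rather than names from this diff:

    # The model now travels through **kwargs.
    predictor = DygraphPredictor(cfg, tokenizer=tok, model=mdl)

    # Omitting it is still valid: kwargs.get("model", None) simply leaves
    # self.model as None until a model is attached later.
    predictor = DygraphPredictor(cfg, tokenizer=tok)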
@@ -326,7 +324,7 @@ def stream_predict(self, inputs: dict[str, paddle.Tensor]): |
326 | 324 |
327 | 325 |
328 | 326 | class StaticGraphPredictor(BasePredictor): |
329 | | - def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None): |
| 327 | + def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs): |
330 | 328 | super().__init__(config, tokenizer) |
331 | 329 |
332 | 330 | inference_config = paddle.inference.Config(self.config.model_name_or_path, self.config.model_prefix) |
@@ -623,14 +621,16 @@ def _preprocess(self, source): |
623 | 621 | return inputs |
624 | 622 |
625 | 623 |
626 | | -class StaticInferencePredictor(InferencePredictorMixin): |
| 624 | +class StaticGraphInferencePredictor(InferencePredictorMixin): |
627 | 625 | def __init__( |
628 | 626 | self, |
629 | 627 | config: PredictorArgument, |
630 | | - cache_kvs_shape: list[list[int]], |
631 | 628 | tokenizer: PretrainedTokenizer = None, |
| 629 | + **kwargs, |
632 | 630 | ): |
633 | | - self.cache_kvs_shape = cache_kvs_shape |
| 631 | + self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None) |
| 632 | + if self.cache_kvs_shape is None: |
| 633 | + raise ValueError("cache_kvs_shape should be provided for StaticGraphInferencePredictor") |
634 | 634 | InferencePredictorMixin.__init__(self, config, tokenizer) |
635 | 635 |
636 | 636 | self.predictor = self._create_predictor(config) |
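Note: `cache_kvs_shape` moves out of the positional signature but remains mandatory for this class, so the constructor validates it eagerly instead of letting an `AttributeError` surface later inside the mixin. A sketch of both paths, with made-up placeholder values:

    # Illustrative only; real shapes come from model.get_cache_kvs_shape().
    shapes = [[2, 1, 32, 2048, 128]]
    predictor = StaticGraphInferencePredictor(cfg, tokenizer=tok, cache_kvs_shape=shapes)

    # Forgetting the keyword argument now fails fast.
    StaticGraphInferencePredictor(cfg, tokenizer=tok)  # raises ValueError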
@@ -701,9 +701,12 @@ class DygraphInferencePredictor(InferencePredictorMixin): |
701 | 701 | def __init__( |
702 | 702 | self, |
703 | 703 | config: PredictorArgument, |
704 | | - model: PretrainedModel = None, |
705 | 704 | tokenizer: PretrainedTokenizer = None, |
| 705 | + **kwargs, |
706 | 706 | ): |
| 707 | + model = kwargs.get("model", None) |
| 708 | + if model is None: |
| 709 | + raise ValueError("model should be provided for DygraphInferencePredictor") |
707 | 710 | self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length) |
708 | 711 | InferencePredictorMixin.__init__(self, config, tokenizer) |
709 | 712 | self.model = model |
@@ -982,12 +985,10 @@ def _preprocess(self, input_text: list[str]): |
982 | 985 |
983 | 986 |
984 | 987 | class DygraphBlockInferencePredictor(BlockInferencePredictorMixin): |
985 | | - def __init__( |
986 | | - self, |
987 | | - config: PredictorArgument, |
988 | | - model: PretrainedModel = None, |
989 | | - tokenizer: PretrainedTokenizer = None, |
990 | | - ): |
| 988 | + def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs): |
| 989 | + model = kwargs.get("model", None) |
| 990 | + if model is None: |
| 991 | + raise ValueError("model must be provided to DygraphBlockInferencePredictor") |
991 | 992 | self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size) |
992 | 993 | BlockInferencePredictorMixin.__init__(self, config, tokenizer) |
993 | 994 |
@@ -1079,14 +1080,16 @@ def predict(self, input_texts: list[str], return_tokens=False): |
1079 | 1080 | return outputs |
1080 | 1081 |
1081 | 1082 |
1082 | | -class StaticBlockInferencePredictor(BlockInferencePredictorMixin): |
| 1083 | +class StaticGraphBlockInferencePredictor(BlockInferencePredictorMixin): |
1083 | 1084 | def __init__( |
1084 | 1085 | self, |
1085 | 1086 | config: PredictorArgument, |
1086 | | - cache_kvs_shape: list[list[int]], |
1087 | 1087 | tokenizer: PretrainedTokenizer = None, |
| 1088 | + **kwargs, |
1088 | 1089 | ): |
1089 | | - self.cache_kvs_shape = cache_kvs_shape |
| 1090 | + self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None) |
| 1091 | + if self.cache_kvs_shape is None: |
| 1092 | + raise ValueError("cache_kvs_shape must be provided to StaticGraphBlockInferencePredictor") |
1090 | 1093 | BlockInferencePredictorMixin.__init__(self, config, tokenizer) |
1091 | 1094 |
1092 | 1095 | self._create_predictor(config) |
@@ -1224,21 +1227,71 @@ def predict(self, input_texts: list[str], return_tokens=False): |
1224 | 1227 | return outputs |
1225 | 1228 |
1226 | 1229 |
1227 | | -def get_ptq_multicards_num(directory): |
1228 | | - count = 0 |
1229 | | - if os.path.exists(directory): |
1230 | | - prefix = "act_scales_" |
1231 | | - for filename in os.listdir(directory): |
1232 | | - if filename.startswith(prefix): |
1233 | | - count += 1 |
1234 | | - return count |
| 1230 | +class AutoPredictor: |
| 1231 | + def __init__(self, *args, **kwargs): |
| 1232 | + raise EnvironmentError( |
| 1233 | + f"{self.__class__.__name__} is designed to be instantiated " |
| 1234 | + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path).`" |
| 1235 | + ) |
| 1236 | + |
| 1237 | + @classmethod |
| 1238 | + def create_predictor( |
| 1239 | + cls, |
| 1240 | + predictor_args: PredictorArgument, |
| 1241 | + config: PretrainedConfig, |
| 1242 | + model_args: ModelArgument, |
| 1243 | + tokenizer: PretrainedTokenizer = None, |
| 1244 | + **kwargs, |
| 1245 | + ): |
| 1246 | + """ |
| 1247 | + Create and return the predictor that matches the given arguments. |
| 1248 | + |
| 1249 | + Args: |
| 1250 | + predictor_args (PredictorArgument): The predictor arguments. |
| 1251 | + config (PretrainedConfig): The model configuration. |
| 1252 | + model_args (ModelArgument): The model arguments. |
| 1253 | + tokenizer (PretrainedTokenizer): The tokenizer. |
| 1254 | + **kwargs: Additional keyword arguments. |
| 1255 | + Returns: |
| 1256 | + The predictor instance selected by execution mode (dygraph or static graph), inference flag, and attention type. |
| 1257 | + """ |
| 1258 | + model = kwargs.pop("model", None) |
| 1259 | + cache_kvs_shape = None |
| 1260 | + |
| 1261 | + # Execution mode: dygraph (dynamic) or static graph. |
| 1262 | + execute_mode = "Dygraph" if predictor_args.mode == "dynamic" else "StaticGraph" |
| 1263 | + |
| 1264 | + # Inference-optimized predictor or plain generation predictor. |
| 1265 | + if predictor_args.inference_model: |
| 1266 | + # Block attention or conventional attention. |
| 1267 | + if predictor_args.block_attn: |
| 1268 | + attn_type = "Block" |
| 1269 | + else: |
| 1270 | + attn_type = "" |
| 1271 | + inference_mode = f"{attn_type}Inference" |
| 1272 | + |
| 1273 | + if predictor_args.mode == "static": |
| 1274 | + cache_kvs_shape = model.get_cache_kvs_shape( |
| 1275 | + config, predictor_args.batch_size, predictor_args.total_max_length |
| 1276 | + ) |
| 1277 | + else: |
| 1278 | + inference_mode = "" |
| 1279 | + |
| 1280 | + predictor_class_name = execute_mode + inference_mode + "Predictor" |
| 1281 | + |
| 1282 | + current_module = sys.modules[__name__] |
| 1283 | + |
| 1284 | + # Resolve the predictor class defined in this module from its composed name. |
| 1285 | + predictor_class = getattr(current_module, predictor_class_name) |
| 1286 | + |
| 1287 | + # Instantiate it; each constructor reads what it needs from the keyword arguments. |
| 1288 | + predictor = predictor_class(predictor_args, tokenizer=tokenizer, model=model, cache_kvs_shape=cache_kvs_shape) |
| 1289 | + return predictor |
1235 | 1290 |
1236 | 1291 |
1237 | 1292 | def create_predictor( |
1238 | 1293 | predictor_args: PredictorArgument, |
1239 | 1294 | model_args: ModelArgument, |
1240 | | - tensor_parallel_degree: int = 1, |
1241 | | - tensor_parallel_rank: int = 0, |
1242 | 1295 | ): |
1243 | 1296 | tokenizer = AutoTokenizer.from_pretrained( |
1244 | 1297 | predictor_args.model_name_or_path, |
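Note: the dispatch above composes the class name from three fragments and resolves it via `getattr` on this module, replacing the old if/else ladder. A standalone restatement of that logic, assuming the six predictor classes are defined in the same module:

    def _predictor_class_name(mode: str, inference_model: bool, block_attn: bool) -> str:
        # Mirrors the composition rules in AutoPredictor.create_predictor.
        execute_mode = "Dygraph" if mode == "dynamic" else "StaticGraph"
        inference_mode = ""
        if inference_model:
            inference_mode = ("Block" if block_attn else "") + "Inference"
        return execute_mode + inference_mode + "Predictor"

    # The six reachable class names:
    #   DygraphPredictor, StaticGraphPredictor,
    #   DygraphInferencePredictor, StaticGraphInferencePredictor,
    #   DygraphBlockInferencePredictor, StaticGraphBlockInferencePredictor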
@@ -1272,9 +1325,23 @@ def create_predictor( |
1272 | 1325 | predictor_args.temperature = 1.0 |
1273 | 1326 |
1274 | 1327 | tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env() |
1275 | | - if not predictor_args.inference_model: |
1276 | | - tokenizer.padding_side = "left" |
| 1328 | + |
| 1329 | + model = None |
| 1330 | + |
| 1331 | + # Load the model up front; AutoPredictor then only selects and wires the predictor class. |
| 1332 | + if predictor_args.inference_model: |
| 1333 | + model = AutoInferenceModelForCausalLM.from_pretrained( |
| 1334 | + predictor_args.model_name_or_path, |
| 1335 | + config=config, |
| 1336 | + predictor_args=predictor_args, |
| 1337 | + model_args=model_args, |
| 1338 | + dtype=predictor_args.dtype, |
| 1339 | + tensor_parallel_degree=tensor_parallel_degree, |
| 1340 | + tensor_parallel_rank=tensor_parallel_rank, |
| 1341 | + ) |
| 1342 | + else: |
1277 | 1343 | if predictor_args.mode == "dynamic": |
| 1344 | + # Import a bespoke model implementation (gpt-3, ernie) or fall back to AutoModel. |
1278 | 1345 | if model_args.model_type == "gpt-3": |
1279 | 1346 | sys.path.append("./gpt-3") |
1280 | 1347 | from modeling import GPTForCausalLM |
@@ -1309,47 +1376,7 @@ def create_predictor( |
1309 | 1376 | tensor_parallel_output=False, |
1310 | 1377 | ) |
1311 | 1378 |
1312 | | - predictor = DygraphPredictor(predictor_args, model=model, tokenizer=tokenizer) |
1313 | | - elif predictor_args.mode == "static": |
1314 | | - predictor = StaticGraphPredictor(predictor_args, tokenizer=tokenizer) |
1315 | | - else: |
1316 | | - raise ValueError("the `mode` should be one of [dynamic, static]") |
1317 | | - else: |
1318 | | - if predictor_args.mode == "dynamic": |
1319 | | - model = AutoInferenceModelForCausalLM.from_pretrained( |
1320 | | - predictor_args.model_name_or_path, |
1321 | | - config=config, |
1322 | | - predictor_args=predictor_args, |
1323 | | - model_args=model_args, |
1324 | | - dtype=predictor_args.dtype, |
1325 | | - tensor_parallel_degree=tensor_parallel_degree, |
1326 | | - tensor_parallel_rank=tensor_parallel_rank, |
1327 | | - ) |
1328 | | - model.eval() |
1329 | | - if predictor_args.block_attn: |
1330 | | - predictor = DygraphBlockInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) |
1331 | | - else: |
1332 | | - predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) |
1333 | | - |
1334 | | - elif predictor_args.mode == "static": |
1335 | | - model = AutoInferenceModelForCausalLM.from_pretrained( |
1336 | | - predictor_args.model_name_or_path, |
1337 | | - config=config, |
1338 | | - predictor_args=predictor_args, |
1339 | | - model_args=model_args, |
1340 | | - dtype=predictor_args.dtype, |
1341 | | - tensor_parallel_degree=tensor_parallel_degree, |
1342 | | - tensor_parallel_rank=tensor_parallel_rank, |
1343 | | - ) |
1344 | | - cache_kvs_shape = model.get_cache_kvs_shape( |
1345 | | - config, predictor_args.batch_size, predictor_args.total_max_length |
1346 | | - ) |
1347 | | - if predictor_args.block_attn: |
1348 | | - predictor = StaticBlockInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) |
1349 | | - else: |
1350 | | - predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) |
1351 | | - else: |
1352 | | - raise ValueError("the `mode` should be one of [dynamic, static]") |
| 1379 | + predictor = AutoPredictor.create_predictor(predictor_args, config, model_args, tokenizer, model=model) |
1353 | 1380 |
1354 | 1381 | return predictor |
1355 | 1382 |
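Note: callers no longer pass `tensor_parallel_degree`/`tensor_parallel_rank`; both are now derived internally from `llm_utils.init_dist_env()`. A hedged end-to-end sketch; the module path and the argument fields shown are assumptions about `PredictorArgument`/`ModelArgument`, not values taken from this diff:

    from predict.predictor import ModelArgument, PredictorArgument, create_predictor

    predictor_args = PredictorArgument(model_name_or_path="my-model")  # illustrative path
    model_args = ModelArgument()
    predictor = create_predictor(predictor_args, model_args)
    outputs = predictor.predict(["Hello, world!"])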