@@ -41,7 +41,7 @@
     ChatGLMv2Tokenizer,
     Llama3Tokenizer,
     LlamaTokenizer,
-    PretrainedModel,
+    PretrainedConfig,
     PretrainedTokenizer,
 )
 from paddlenlp.trl import llm_utils
@@ -245,11 +245,9 @@ def predict(self, input_texts: str | list[str], return_tokens=False):


 class DygraphPredictor(BasePredictor):
-    def __init__(
-        self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None
-    ):
+    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs):
         super().__init__(config, tokenizer)
-        self.model = model
+        self.model = kwargs.get("model", None)
         if config.lora_path is not None:
             lora_config = LoRAConfig.from_pretrained(config.lora_path)
             dtype = lora_config.dtype
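This hunk, and the constructor hunks that follow, converge every predictor on one signature, (config, tokenizer=None, **kwargs), with optional collaborators such as model and cache_kvs_shape pulled out of **kwargs. That uniformity is what lets the AutoPredictor factory added at the end of this diff instantiate any predictor class through a single generic call. A minimal sketch of the pattern (the class and attribute names below are illustrative, not part of the source):

    class ExamplePredictor:
        def __init__(self, config, tokenizer=None, **kwargs):
            # Optional dependencies arrive via **kwargs so one factory
            # call can construct every predictor variant.
            self.model = kwargs.get("model", None)
            self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None)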
@@ -326,7 +324,7 @@ def stream_predict(self, inputs: dict[str, paddle.Tensor]):


 class StaticGraphPredictor(BasePredictor):
-    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None):
+    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs):
         super().__init__(config, tokenizer)

         inference_config = paddle.inference.Config(self.config.model_name_or_path, self.config.model_prefix)
@@ -623,14 +621,16 @@ def _preprocess(self, source):
         return inputs


-class StaticInferencePredictor(InferencePredictorMixin):
+class StaticGraphInferencePredictor(InferencePredictorMixin):
     def __init__(
         self,
         config: PredictorArgument,
-        cache_kvs_shape: list[list[int]],
         tokenizer: PretrainedTokenizer = None,
+        **kwargs,
     ):
-        self.cache_kvs_shape = cache_kvs_shape
+        self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None)
+        if self.cache_kvs_shape is None:
+            raise ValueError("cache_kvs_shape should be provided for StaticGraphInferencePredictor")
         InferencePredictorMixin.__init__(self, config, tokenizer)

         self.predictor = self._create_predictor(config)
@@ -701,9 +701,12 @@ class DygraphInferencePredictor(InferencePredictorMixin):
     def __init__(
         self,
         config: PredictorArgument,
-        model: PretrainedModel = None,
         tokenizer: PretrainedTokenizer = None,
+        **kwargs,
     ):
+        model = kwargs.get("model", None)
+        if model is None:
+            raise ValueError("model should be provided for DygraphInferencePredictor")
         self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length)
         InferencePredictorMixin.__init__(self, config, tokenizer)
         self.model = model
@@ -982,12 +985,10 @@ def _preprocess(self, input_text: list[str]):


 class DygraphBlockInferencePredictor(BlockInferencePredictorMixin):
-    def __init__(
-        self,
-        config: PredictorArgument,
-        model: PretrainedModel = None,
-        tokenizer: PretrainedTokenizer = None,
-    ):
+    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs):
+        model = kwargs.get("model", None)
+        if model is None:
+            raise ValueError("model should be provided for DygraphBlockInferencePredictor")
         self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size)
         BlockInferencePredictorMixin.__init__(self, config, tokenizer)

@@ -1079,14 +1080,16 @@ def predict(self, input_texts: list[str], return_tokens=False):
         return outputs


-class StaticBlockInferencePredictor(BlockInferencePredictorMixin):
+class StaticGraphBlockInferencePredictor(BlockInferencePredictorMixin):
     def __init__(
         self,
         config: PredictorArgument,
-        cache_kvs_shape: list[list[int]],
         tokenizer: PretrainedTokenizer = None,
+        **kwargs,
     ):
-        self.cache_kvs_shape = cache_kvs_shape
+        self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None)
+        if self.cache_kvs_shape is None:
+            raise ValueError("cache_kvs_shape should be provided for StaticGraphBlockInferencePredictor")
         BlockInferencePredictorMixin.__init__(self, config, tokenizer)

         self._create_predictor(config)
@@ -1224,21 +1227,71 @@ def predict(self, input_texts: list[str], return_tokens=False):
         return outputs


-def get_ptq_multicards_num(directory):
-    count = 0
-    if os.path.exists(directory):
-        prefix = "act_scales_"
-        for filename in os.listdir(directory):
-            if filename.startswith(prefix):
-                count += 1
-    return count
+class AutoPredictor:
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.create_predictor(...)` classmethod."
+        )
+
+    @classmethod
+    def create_predictor(
+        cls,
+        predictor_args: PredictorArgument,
+        config: PretrainedConfig,
+        model_args: ModelArgument,
+        tokenizer: PretrainedTokenizer = None,
+        **kwargs,
+    ):
+        """
+        Create a predictor instance matching the predictor arguments.
+
+        Args:
+            predictor_args (PredictorArgument): The predictor arguments.
+            config (PretrainedConfig): The model configuration.
+            model_args (ModelArgument): The model arguments.
+            tokenizer (PretrainedTokenizer): The tokenizer.
+            **kwargs: Additional keyword arguments.
+        Returns:
+            Predictor: The predictor.
+        """
+        model = kwargs.pop("model", None)
+        cache_kvs_shape = None
+
+        # execution mode: dynamic graph or static graph
+        execute_mode = "Dygraph" if predictor_args.mode == "dynamic" else "StaticGraph"
+
+        # inference-optimized model or plain generation
+        if predictor_args.inference_model:
+            # block attention or regular attention
+            if predictor_args.block_attn:
+                attn_type = "Block"
+            else:
+                attn_type = ""
+            inference_mode = f"{attn_type}Inference"
+
+            if predictor_args.mode == "static":
+                cache_kvs_shape = model.get_cache_kvs_shape(
+                    config, predictor_args.batch_size, predictor_args.total_max_length
+                )
+        else:
+            inference_mode = ""
+
+        predictor_class_name = execute_mode + inference_mode + "Predictor"
+
+        import_class = sys.modules[__name__]
+
+        # resolve the predictor class from this module
+        predictor_class = getattr(import_class, predictor_class_name)
+
+        # instantiate the selected predictor
+        predictor = predictor_class(predictor_args, tokenizer=tokenizer, model=model, cache_kvs_shape=cache_kvs_shape)
+        return predictor


 def create_predictor(
     predictor_args: PredictorArgument,
     model_args: ModelArgument,
-    tensor_parallel_degree: int = 1,
-    tensor_parallel_rank: int = 0,
 ):
     tokenizer = AutoTokenizer.from_pretrained(
         predictor_args.model_name_or_path,
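The factory builds the concrete class name from three independent switches, so the long if/elif dispatch in the old create_predictor collapses into string composition. A quick illustration of how the pieces combine (the SimpleNamespace stands in for PredictorArgument; only mode, inference_model, and block_attn matter here):

    from types import SimpleNamespace

    def predictor_class_name(args):
        # Mirrors the composition logic in AutoPredictor.create_predictor above.
        execute_mode = "Dygraph" if args.mode == "dynamic" else "StaticGraph"
        if args.inference_model:
            inference_mode = ("Block" if args.block_attn else "") + "Inference"
        else:
            inference_mode = ""
        return execute_mode + inference_mode + "Predictor"

    args = SimpleNamespace(mode="dynamic", inference_model=True, block_attn=True)
    print(predictor_class_name(args))  # DygraphBlockInferencePredictor

    args = SimpleNamespace(mode="static", inference_model=False, block_attn=False)
    print(predictor_class_name(args))  # StaticGraphPredictor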
@@ -1272,9 +1325,23 @@ def create_predictor(
         predictor_args.temperature = 1.0

     tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env()
-    if not predictor_args.inference_model:
-        tokenizer.padding_side = "left"
+
+    model = None
+
+    # model loading
+    if predictor_args.inference_model:
+        model = AutoInferenceModelForCausalLM.from_pretrained(
+            predictor_args.model_name_or_path,
+            config=config,
+            predictor_args=predictor_args,
+            model_args=model_args,
+            dtype=predictor_args.dtype,
+            tensor_parallel_degree=tensor_parallel_degree,
+            tensor_parallel_rank=tensor_parallel_rank,
+        )
+    else:
         if predictor_args.mode == "dynamic":
+            # import a task-specific model (gpt-3, ernie) or fall back to AutoModel
             if model_args.model_type == "gpt-3":
                 sys.path.append("./gpt-3")
                 from modeling import GPTForCausalLM
@@ -1309,47 +1376,7 @@ def create_predictor(
                 tensor_parallel_output=False,
             )

-            predictor = DygraphPredictor(predictor_args, model=model, tokenizer=tokenizer)
-        elif predictor_args.mode == "static":
-            predictor = StaticGraphPredictor(predictor_args, tokenizer=tokenizer)
-        else:
-            raise ValueError("the `mode` should be one of [dynamic, static]")
-    else:
-        if predictor_args.mode == "dynamic":
-            model = AutoInferenceModelForCausalLM.from_pretrained(
-                predictor_args.model_name_or_path,
-                config=config,
-                predictor_args=predictor_args,
-                model_args=model_args,
-                dtype=predictor_args.dtype,
-                tensor_parallel_degree=tensor_parallel_degree,
-                tensor_parallel_rank=tensor_parallel_rank,
-            )
-            model.eval()
-            if predictor_args.block_attn:
-                predictor = DygraphBlockInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
-            else:
-                predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
-
-        elif predictor_args.mode == "static":
-            model = AutoInferenceModelForCausalLM.from_pretrained(
-                predictor_args.model_name_or_path,
-                config=config,
-                predictor_args=predictor_args,
-                model_args=model_args,
-                dtype=predictor_args.dtype,
-                tensor_parallel_degree=tensor_parallel_degree,
-                tensor_parallel_rank=tensor_parallel_rank,
-            )
-            cache_kvs_shape = model.get_cache_kvs_shape(
-                config, predictor_args.batch_size, predictor_args.total_max_length
-            )
-            if predictor_args.block_attn:
-                predictor = StaticBlockInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
-            else:
-                predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
-        else:
-            raise ValueError("the `mode` should be one of [dynamic, static]")
+    predictor = AutoPredictor.create_predictor(predictor_args, config, model_args, tokenizer, model=model)

     return predictor

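After this change, the call site no longer picks a predictor class by hand; it builds the arguments and lets the factory dispatch. A hypothetical end-to-end call under the new two-argument signature (the field names come from this diff, but how PredictorArgument and ModelArgument are constructed is an assumption for illustration):

    # Assumed: PredictorArgument/ModelArgument dataclasses as used above;
    # create_predictor loads the tokenizer and model internally.
    predictor_args = PredictorArgument(
        model_name_or_path="meta-llama/Meta-Llama-3-8B",  # any supported checkpoint
        mode="dynamic",          # or "static"
        inference_model=True,    # use the fused inference model
        block_attn=True,         # resolves to DygraphBlockInferencePredictor
    )
    model_args = ModelArgument()
    predictor = create_predictor(predictor_args, model_args)
    outputs = predictor.predict(["Hello, PaddleNLP!"])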