
Commit 2c1387f

[INFER][LLM] Add the AutoPredictor for inference (#9445)

* add the AutoPredictor
* decouple the model loading and the predictor loading
* polish the AutoPredictor and AutoModel

1 parent 7a221cc commit 2c1387f

3 files changed: +103 -74 lines changed


llm/predict/export_model.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def main():
         tensor_parallel_rank = hcg.get_model_parallel_rank()
 
     # set predictor type
-    predictor = create_predictor(predictor_args, model_args, tensor_parallel_degree, tensor_parallel_rank)
+    predictor = create_predictor(predictor_args, model_args)
    predictor.model.eval()
 
    predictor.model.to_static(
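
Why the call site shrinks: create_predictor now resolves the tensor-parallel degree and rank itself via llm_utils.init_dist_env() (see the predictor.py diff below), so callers no longer thread them through. A minimal sketch of the updated call site, assuming the PdArgumentParser setup that surrounds this line in export_model.py (import paths may differ in the real script):

# Sketch only: the two argument dataclasses come from llm/predict/predictor.py.
from paddlenlp.trainer import PdArgumentParser
from predict.predictor import ModelArgument, PredictorArgument, create_predictor

parser = PdArgumentParser((PredictorArgument, ModelArgument))
predictor_args, model_args = parser.parse_args_into_dataclasses()

# Tensor-parallel degree/rank are now derived inside create_predictor.
predictor = create_predictor(predictor_args, model_args)
predictor.model.eval()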

llm/predict/predictor.py

Lines changed: 99 additions & 72 deletions
@@ -41,7 +41,7 @@
     ChatGLMv2Tokenizer,
     Llama3Tokenizer,
     LlamaTokenizer,
-    PretrainedModel,
+    PretrainedConfig,
     PretrainedTokenizer,
 )
 from paddlenlp.trl import llm_utils
@@ -245,11 +245,9 @@ def predict(self, input_texts: str | list[str], return_tokens=False):
 
 
 class DygraphPredictor(BasePredictor):
-    def __init__(
-        self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None
-    ):
+    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs):
         super().__init__(config, tokenizer)
-        self.model = model
+        self.model = kwargs.get("model", None)
         if config.lora_path is not None:
             lora_config = LoRAConfig.from_pretrained(config.lora_path)
             dtype = lora_config.dtype
@@ -326,7 +324,7 @@ def stream_predict(self, inputs: dict[str, paddle.Tensor]):
 
 
 class StaticGraphPredictor(BasePredictor):
-    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None):
+    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs):
         super().__init__(config, tokenizer)
 
         inference_config = paddle.inference.Config(self.config.model_name_or_path, self.config.model_prefix)
@@ -623,14 +621,16 @@ def _preprocess(self, source):
         return inputs
 
 
-class StaticInferencePredictor(InferencePredictorMixin):
+class StaticGraphInferencePredictor(InferencePredictorMixin):
     def __init__(
         self,
         config: PredictorArgument,
-        cache_kvs_shape: list[list[int]],
         tokenizer: PretrainedTokenizer = None,
+        **kwargs,
     ):
-        self.cache_kvs_shape = cache_kvs_shape
+        self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None)
+        if self.cache_kvs_shape is None:
+            raise ValueError("cache_kvs_shape should be provided for StaticGraphInferencePredictor")
         InferencePredictorMixin.__init__(self, config, tokenizer)
 
         self.predictor = self._create_predictor(config)
@@ -701,9 +701,12 @@ class DygraphInferencePredictor(InferencePredictorMixin):
     def __init__(
         self,
         config: PredictorArgument,
-        model: PretrainedModel = None,
         tokenizer: PretrainedTokenizer = None,
+        **kwargs,
     ):
+        model = kwargs.get("model", None)
+        if model is None:
+            raise ValueError("model should be provided for DygraphInferencePredictor")
         self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length)
         InferencePredictorMixin.__init__(self, config, tokenizer)
         self.model = model
@@ -982,12 +985,10 @@ def _preprocess(self, input_text: list[str]):
 
 
 class DygraphBlockInferencePredictor(BlockInferencePredictorMixin):
-    def __init__(
-        self,
-        config: PredictorArgument,
-        model: PretrainedModel = None,
-        tokenizer: PretrainedTokenizer = None,
-    ):
+    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs):
+        model = kwargs.get("model", None)
+        if model is None:
+            raise ValueError("model should be provided for DygraphBlockInferencePredictor")
         self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size)
         BlockInferencePredictorMixin.__init__(self, config, tokenizer)
 
@@ -1079,14 +1080,16 @@ def predict(self, input_texts: list[str], return_tokens=False):
         return outputs
 
 
-class StaticBlockInferencePredictor(BlockInferencePredictorMixin):
+class StaticGraphBlockInferencePredictor(BlockInferencePredictorMixin):
     def __init__(
         self,
         config: PredictorArgument,
-        cache_kvs_shape: list[list[int]],
         tokenizer: PretrainedTokenizer = None,
+        **kwargs,
     ):
-        self.cache_kvs_shape = cache_kvs_shape
+        self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None)
+        if self.cache_kvs_shape is None:
+            raise ValueError("cache_kvs_shape should be provided for StaticGraphBlockInferencePredictor")
         BlockInferencePredictorMixin.__init__(self, config, tokenizer)
 
         self._create_predictor(config)
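
One pattern repeats across the four constructors above: each required dependency (model or cache_kvs_shape) moves from a dedicated positional parameter into **kwargs, paired with a fail-fast check. The classes give up their bespoke signatures so that a single uniform call works for all of them, which the AutoPredictor below relies on. A distilled sketch of the pattern, using a hypothetical class name:

class ExamplePredictor:  # hypothetical; mirrors the constructors in this diff
    def __init__(self, config, tokenizer=None, **kwargs):
        # Required by this concrete class, optional in the shared signature.
        self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None)
        if self.cache_kvs_shape is None:
            raise ValueError("cache_kvs_shape should be provided for ExamplePredictor")

The trade-off is that a missing argument now surfaces at runtime as a ValueError rather than as a TypeError raised by the signature itself.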
@@ -1224,21 +1227,71 @@ def predict(self, input_texts: list[str], return_tokens=False):
         return outputs
 
 
-def get_ptq_multicards_num(directory):
-    count = 0
-    if os.path.exists(directory):
-        prefix = "act_scales_"
-        for filename in os.listdir(directory):
-            if filename.startswith(prefix):
-                count += 1
-    return count
+class AutoPredictor:
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` method."
+        )
+
+    @classmethod
+    def create_predictor(
+        cls,
+        predictor_args: PredictorArgument,
+        config: PretrainedConfig,
+        model_args: ModelArgument,
+        tokenizer: PretrainedTokenizer = None,
+        **kwargs
+    ):
+        """
+        Create a predictor
+
+        Args:
+            predictor_args (PredictorArgument): The predictor arguments.
+            config (PretrainedConfig): The model configuration.
+            model_args (ModelArgument): The model arguments.
+            tokenizer (PretrainedTokenizer): The tokenizer.
+            **kwargs: Additional keyword arguments.
+        Returns:
+            Predictor: The predictor.
+        """
+        model = kwargs.pop("model", None)
+        cache_kvs_shape = None
+
+        # static or dynamic
+        execute_mode = "Dygraph" if predictor_args.mode == "dynamic" else "StaticGraph"
+
+        # inference model or not
+        if predictor_args.inference_model:
+            # block attention or not
+            if predictor_args.block_attn:
+                attn_type = "Block"
+            else:
+                attn_type = ""
+            inference_mode = f"{attn_type}Inference"
+
+            if predictor_args.mode == "static":
+                cache_kvs_shape = model.get_cache_kvs_shape(
+                    config, predictor_args.batch_size, predictor_args.total_max_length
+                )
+        else:
+            inference_mode = ""
+
+        predictor_class_name = execute_mode + inference_mode + "Predictor"
+
+        import_class = sys.modules[__name__]
+
+        # resolve the predictor class from this module's namespace
+        predictor_class = getattr(import_class, predictor_class_name)
+
+        # instantiate the predictor
+        predictor = predictor_class(predictor_args, tokenizer=tokenizer, model=model, cache_kvs_shape=cache_kvs_shape)
+        return predictor
 
 
 def create_predictor(
     predictor_args: PredictorArgument,
     model_args: ModelArgument,
-    tensor_parallel_degree: int = 1,
-    tensor_parallel_rank: int = 0,
 ):
     tokenizer = AutoTokenizer.from_pretrained(
         predictor_args.model_name_or_path,
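
The dispatch in AutoPredictor.create_predictor is pure string composition: an execute mode ("Dygraph" or "StaticGraph") plus an inference mode ("", "Inference", or "BlockInference") plus "Predictor", resolved against this module's namespace with getattr. This only works because every predictor class now shares the (config, tokenizer=None, **kwargs) signature introduced earlier in this diff. A self-contained toy sketch of the technique (stand-in classes, not the real predictors):

import sys

class DygraphPredictor: ...
class StaticGraphPredictor: ...
class DygraphInferencePredictor: ...
class StaticGraphInferencePredictor: ...
class DygraphBlockInferencePredictor: ...
class StaticGraphBlockInferencePredictor: ...

def resolve_predictor_class(mode: str, inference_model: bool, block_attn: bool):
    # Compose the class name exactly as AutoPredictor.create_predictor does.
    execute_mode = "Dygraph" if mode == "dynamic" else "StaticGraph"
    inference_mode = (("Block" if block_attn else "") + "Inference") if inference_model else ""
    # Look the composed name up in this module's own namespace.
    return getattr(sys.modules[__name__], execute_mode + inference_mode + "Predictor")

assert resolve_predictor_class("dynamic", True, True) is DygraphBlockInferencePredictor
assert resolve_predictor_class("static", False, False) is StaticGraphPredictor

The six valid flag combinations map onto the six concrete classes; any other composed name fails loudly with an AttributeError at the getattr.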
@@ -1272,9 +1325,23 @@ def create_predictor(
         predictor_args.temperature = 1.0
 
     tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env()
-    if not predictor_args.inference_model:
-        tokenizer.padding_side = "left"
+
+    model = None
+
+    # model loading
+    if predictor_args.inference_model:
+        model = AutoInferenceModelForCausalLM.from_pretrained(
+            predictor_args.model_name_or_path,
+            config=config,
+            predictor_args=predictor_args,
+            model_args=model_args,
+            dtype=predictor_args.dtype,
+            tensor_parallel_degree=tensor_parallel_degree,
+            tensor_parallel_rank=tensor_parallel_rank,
+        )
+    else:
         if predictor_args.mode == "dynamic":
+            # model import (gpt-3, ernie) or AutoModel
             if model_args.model_type == "gpt-3":
                 sys.path.append("./gpt-3")
                 from modeling import GPTForCausalLM
@@ -1309,47 +1376,7 @@ def create_predictor(
                 tensor_parallel_output=False,
             )
 
-            predictor = DygraphPredictor(predictor_args, model=model, tokenizer=tokenizer)
-        elif predictor_args.mode == "static":
-            predictor = StaticGraphPredictor(predictor_args, tokenizer=tokenizer)
-        else:
-            raise ValueError("the `mode` should be one of [dynamic, static]")
-    else:
-        if predictor_args.mode == "dynamic":
-            model = AutoInferenceModelForCausalLM.from_pretrained(
-                predictor_args.model_name_or_path,
-                config=config,
-                predictor_args=predictor_args,
-                model_args=model_args,
-                dtype=predictor_args.dtype,
-                tensor_parallel_degree=tensor_parallel_degree,
-                tensor_parallel_rank=tensor_parallel_rank,
-            )
-            model.eval()
-            if predictor_args.block_attn:
-                predictor = DygraphBlockInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
-            else:
-                predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
-
-        elif predictor_args.mode == "static":
-            model = AutoInferenceModelForCausalLM.from_pretrained(
-                predictor_args.model_name_or_path,
-                config=config,
-                predictor_args=predictor_args,
-                model_args=model_args,
-                dtype=predictor_args.dtype,
-                tensor_parallel_degree=tensor_parallel_degree,
-                tensor_parallel_rank=tensor_parallel_rank,
-            )
-            cache_kvs_shape = model.get_cache_kvs_shape(
-                config, predictor_args.batch_size, predictor_args.total_max_length
-            )
-            if predictor_args.block_attn:
-                predictor = StaticBlockInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
-            else:
-                predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
-        else:
-            raise ValueError("the `mode` should be one of [dynamic, static]")
+    predictor = AutoPredictor.create_predictor(predictor_args, config, model_args, tokenizer, model=model)
 
     return predictor
 
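Taken together, create_predictor is now a flat pipeline: load the tokenizer and config, load the model when inference_model is set, then let AutoPredictor.create_predictor pick one of the six predictor classes. A hypothetical end-to-end use (checkpoint name and flag values are placeholders):

from predict.predictor import ModelArgument, PredictorArgument, create_predictor

predictor_args = PredictorArgument(
    model_name_or_path="meta-llama/Llama-2-7b",  # placeholder checkpoint
    mode="dynamic",        # "dynamic" -> Dygraph*, "static" -> StaticGraph*
    inference_model=True,  # adds the "Inference" segment
    block_attn=True,       # adds the "Block" segment
)
model_args = ModelArgument()

# Under these flags the composed name is DygraphBlockInferencePredictor.
predictor = create_predictor(predictor_args, model_args)
print(predictor.predict(["Hello, how are you?"]))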

paddlenlp/transformers/auto/modeling.py

Lines changed: 3 additions & 1 deletion
@@ -858,7 +858,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         )
 
         if predictor_args.mode == "dynamic":
-            return model_class.from_pretrained(predictor_args.model_name_or_path, config=config, dtype=dtype)
+            model = model_class.from_pretrained(predictor_args.model_name_or_path, config=config, dtype=dtype)
+            model.eval()
+            return model
 
         return model_class
 
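Hoisting model.eval() into from_pretrained guarantees that dynamic-mode callers receive a model with train-only behavior such as dropout disabled, rather than relying on each call site to remember it; the model.eval() deleted from create_predictor above was exactly such a site. A minimal illustration of what eval mode changes, using a plain paddle layer:

import paddle

layer = paddle.nn.Dropout(p=0.5)
layer.eval()  # the state from_pretrained now guarantees in dynamic mode
x = paddle.ones([4])
print(layer(x))  # in eval mode dropout is the identity: [1., 1., 1., 1.]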
