Skip to content

Commit 7a8fddd

Browse files
[PIR] Support pir export and infer (#952)
* support export and inference in PIR mode * add validator check
1 parent 4ced874 commit 7a8fddd

File tree

3 files changed

+31
-3
lines changed

3 files changed

+31
-3
lines changed

deploy/python_infer/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing_extensions import Literal
2626

2727
from ppsci.utils import logger
28+
from ppsci.utils import misc
2829

2930
if TYPE_CHECKING:
3031
import onnxruntime
@@ -99,15 +100,19 @@ def predict(self, input_dict):
99100
def _create_paddle_predictor(
100101
self,
101102
) -> Tuple[paddle_inference.Predictor, paddle_inference.Config]:
103+
if misc.check_flag_enabled("FLAGS_enable_pir_api"):
104+
# NOTE: Using 'json' as suffix instead of 'pdmodel' in PIR mode
105+
self.pdmodel_path = self.pdmodel_path.replace(".pdmodel", ".json", 1)
106+
102107
if not osp.exists(self.pdmodel_path):
103108
raise FileNotFoundError(
104109
f"Given 'pdmodel_path': {self.pdmodel_path} does not exist. "
105-
"Please check if it is correct."
110+
"Please check if cfg.INFER.pdmodel_path is correct."
106111
)
107112
if not osp.exists(self.pdiparams_path):
108113
raise FileNotFoundError(
109114
f"Given 'pdiparams_path': {self.pdiparams_path} does not exist. "
110-
"Please check if it is correct."
115+
"Please check if cfg.INFER.pdiparams_path is correct."
111116
)
112117

113118
config = paddle_inference.Config(self.pdmodel_path, self.pdiparams_path)

ppsci/solver/solver.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,11 +913,18 @@ def export(
913913
raise e
914914
logger.message(
915915
f"Inference model has been exported to: {export_path}, including "
916-
"*.pdmodel, *.pdiparams and *.pdiparams.info files."
916+
+ (
917+
"*.json, *.pdiparams files."
918+
if misc.check_flag_enabled("FLAGS_enable_pir_api")
919+
else "*.pdmodel, *.pdiparams and *.pdiparams.info files."
920+
)
917921
)
918922
jit.enable_to_static(False)
919923

920924
if with_onnx:
925+
# TODO: support pir + onnx
926+
if misc.check_flag_enabled("FLAGS_enable_pir_api"):
927+
raise ValueError("paddle2onnx does not support PIR mode yet.")
921928
if not importlib.util.find_spec("paddle2onnx"):
922929
raise ModuleNotFoundError(
923930
"Please install paddle2onnx with `pip install paddle2onnx`"

ppsci/utils/misc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"run_on_eval_mode",
5353
"run_at_rank0",
5454
"plot_curve",
55+
"check_flag_enabled",
5556
]
5657

5758

@@ -631,3 +632,18 @@ def plot_curve(
631632
plt.savefig(os.path.join(output_dir, f"{xlabel}-{ylabel}_curve.jpg"), dpi=200)
632633
plt.clf()
633634
plt.close()
635+
636+
637+
def check_flag_enabled(flag_name: str) -> bool:
638+
"""Check whether the flag is enabled.
639+
640+
Args:
641+
flag_name(str): Flag name to be checked whether enabled or disabled.
642+
643+
Returns:
644+
bool: Whether given flag name is enabled in environment.
645+
"""
646+
value = os.getenv(flag_name, False)
647+
if isinstance(value, str):
648+
return value.lower() in ["true", "1"]
649+
return False

0 commit comments

Comments
 (0)