File tree Expand file tree Collapse file tree 3 files changed +31
-3
lines changed Expand file tree Collapse file tree 3 files changed +31
-3
lines changed Original file line number Diff line number Diff line change 25
25
from typing_extensions import Literal
26
26
27
27
from ppsci .utils import logger
28
+ from ppsci .utils import misc
28
29
29
30
if TYPE_CHECKING :
30
31
import onnxruntime
@@ -99,15 +100,19 @@ def predict(self, input_dict):
99
100
def _create_paddle_predictor (
100
101
self ,
101
102
) -> Tuple [paddle_inference .Predictor , paddle_inference .Config ]:
103
+ if misc .check_flag_enabled ("FLAGS_enable_pir_api" ):
104
+ # NOTE: Using 'json' as suffix instead of 'pdmodel' in PIR mode
105
+ self .pdmodel_path = self .pdmodel_path .replace (".pdmodel" , ".json" , 1 )
106
+
102
107
if not osp .exists (self .pdmodel_path ):
103
108
raise FileNotFoundError (
104
109
f"Given 'pdmodel_path': { self .pdmodel_path } does not exist. "
105
- "Please check if it is correct."
110
+ "Please check if cfg.INFER.pdmodel_path is correct."
106
111
)
107
112
if not osp .exists (self .pdiparams_path ):
108
113
raise FileNotFoundError (
109
114
f"Given 'pdiparams_path': { self .pdiparams_path } does not exist. "
110
- "Please check if it is correct."
115
+ "Please check if cfg.INFER.pdiparams_path is correct."
111
116
)
112
117
113
118
config = paddle_inference .Config (self .pdmodel_path , self .pdiparams_path )
Original file line number Diff line number Diff line change @@ -913,11 +913,18 @@ def export(
913
913
raise e
914
914
logger .message (
915
915
f"Inference model has been exported to: { export_path } , including "
916
- "*.pdmodel, *.pdiparams and *.pdiparams.info files."
916
+ + (
917
+ "*.json, *.pdiparams files."
918
+ if misc .check_flag_enabled ("FLAGS_enable_pir_api" )
919
+ else "*.pdmodel, *.pdiparams and *.pdiparams.info files."
920
+ )
917
921
)
918
922
jit .enable_to_static (False )
919
923
920
924
if with_onnx :
925
+ # TODO: support pir + onnx
926
+ if misc .check_flag_enabled ("FLAGS_enable_pir_api" ):
927
+ raise ValueError ("paddle2onnx does not support PIR mode yet." )
921
928
if not importlib .util .find_spec ("paddle2onnx" ):
922
929
raise ModuleNotFoundError (
923
930
"Please install paddle2onnx with `pip install paddle2onnx`"
Original file line number Diff line number Diff line change 52
52
"run_on_eval_mode" ,
53
53
"run_at_rank0" ,
54
54
"plot_curve" ,
55
+ "check_flag_enabled" ,
55
56
]
56
57
57
58
@@ -631,3 +632,18 @@ def plot_curve(
631
632
plt .savefig (os .path .join (output_dir , f"{ xlabel } -{ ylabel } _curve.jpg" ), dpi = 200 )
632
633
plt .clf ()
633
634
plt .close ()
635
+
636
+
637
+ def check_flag_enabled (flag_name : str ) -> bool :
638
+ """Check whether the flag is enabled.
639
+
640
+ Args:
641
+ flag_name(str): Flag name to be checked whether enabled or disabled.
642
+
643
+ Returns:
644
+ bool: Whether given flag name is enabled in environment.
645
+ """
646
+ value = os .getenv (flag_name , False )
647
+ if isinstance (value , str ):
648
+ return value .lower () in ["true" , "1" ]
649
+ return False
You can’t perform that action at this time.
0 commit comments