Skip to content

Commit a6a79c3

Browse files
committed
More general implementation.
1 parent 5ecbba4 commit a6a79c3

File tree

1 file changed

+8
-4
lines changed
  • python/paddle/v2/fluid

1 file changed

+8
-4
lines changed

python/paddle/v2/fluid/io.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,16 @@ def load_persistables(executor, dirname, main_program=None):
186186
def get_inference_program(target_vars, main_program=None):
187187
if main_program is None:
188188
main_program = default_main_program()
189-
if isinstance(target_vars, Evaluator):
190-
target_vars = target_vars.states + target_vars.metrics
191189
if not isinstance(target_vars, list):
192190
target_vars = [target_vars]
193-
194-
pruned_program = main_program.prune(targets=target_vars)
191+
vars = []
192+
for var in target_vars:
193+
if isinstance(var, Evaluator):
194+
vars.append(var.states)
195+
vars.append(var.metrics)
196+
else:
197+
vars.append(var)
198+
pruned_program = main_program.prune(targets=vars)
195199
inference_program = pruned_program.inference_optimize()
196200
return inference_program
197201

0 commit comments

Comments
 (0)