Skip to content

Commit 4274883

Browse files
committed
add field "prob" in paddle.infer
1 parent aa230bf commit 4274883

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

paddle/py_paddle/util.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,17 @@ def __arguments_to_numpy__(i, arg):
8383
assert isinstance(arg, swig_paddle.Arguments)
8484
value = arg.getSlotValue(i)
8585
ids = arg.getSlotIds(i)
86+
prob = arg.getSlotIn(i)
8687
if value is not None:
8788
assert isinstance(value, swig_paddle.Matrix)
8889
value = value.copyToNumpyMat()
8990
if ids is not None:
9091
assert isinstance(ids, swig_paddle.IVector)
9192
ids = ids.copyToNumpyArray()
92-
return {"value": value, "id": ids}
93+
if prob is not None:
94+
assert isinstance(prob, swig_paddle.Matrix)
95+
prob = prob.copyToNumpyMat()
96+
return {"value": value, "id": ids, "prob": prob}
9397

9498

9599
def __monkeypatch_gradient_machine__():

python/paddle/v2/inference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,11 @@ def infer(output_layer, parameters, input, feeding=None, field='value'):
8181
:type input: collections.Iterable
8282
:param feeding: Reader dictionary. Default could generate from input
8383
value.
84-
:param field: The prediction field. It should in [`value`, `ids`]. `value`
85-
means return the prediction probabilities, `ids` means return
86-
the prediction labels. Default is `value`
84+
:param field: The prediction field. It should in [`value`, `id`, `prob`].
85+
`value` and `prob` mean return the prediction probabilities,
86+
`id` means return the prediction labels. Default is `value`.
87+
Note that `prob` only used when output_layer is beam_search
88+
or max_id.
8789
:type field: str
8890
:return: a numpy array
8991
:rtype: numpy.ndarray

0 commit comments

Comments
 (0)