Skip to content

Commit 0b854bd

Browse files
committed
Add sequence_erase option into edit distance python API
1 parent a8f118c commit 0b854bd

File tree

1 file changed

+22
-1
lines changed
  • python/paddle/v2/fluid/layers

1 file changed

+22
-1
lines changed

python/paddle/v2/fluid/layers/nn.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1864,7 +1864,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
18641864
return out
18651865

18661866

1867-
def edit_distance(input, label, normalized=False, name=None):
1867+
def edit_distance(input, label, normalized=False, tokens=None, name=None):
18681868
"""
18691869
EditDistance operator computes the edit distances between a batch of hypothesis strings and their references. Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations needed to transform one string into another. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3, because transforming A into B takes at least two substitutions and one insertion:
18701870
@@ -1882,6 +1882,8 @@ def edit_distance(input, label, normalized=False, name=None):
18821882
18831883
normalized(bool): Indicates whether to normalize the edit distance by the length of the reference string.
18841884
1885+
tokens(list): Tokens that should be removed before calculating edit distance.
1886+
18851887
Returns:
18861888
Variable: sequence-to-sequence edit distance loss in shape [batch_size, 1].
18871889
@@ -1895,6 +1897,25 @@ def edit_distance(input, label, normalized=False, name=None):
18951897
"""
18961898
helper = LayerHelper("edit_distance", **locals())
18971899

1900+
# remove some tokens from input and labels
1901+
if tokens is not None and len(tokens) > 0:
1902+
erased_input = helper.create_tmp_variable(dtype="int64")
1903+
erased_label = helper.create_tmp_variable(dtype="int64")
1904+
1905+
helper.append_op(
1906+
type="sequence_erase",
1907+
inputs={"X": [input]},
1908+
outputs={"Out": [erased_input]},
1909+
attrs={"tokens": tokens})
1910+
input = erased_input
1911+
1912+
helper.append_op(
1913+
type="sequence_erase",
1914+
inputs={"X": [label]},
1915+
outputs={"Out": [erase_label]},
1916+
attrs={"tokens": tokens})
1917+
label = erased_label
1918+
18981919
# edit distance op
18991920
edit_distance_out = helper.create_tmp_variable(dtype="int64")
19001921
helper.append_op(

0 commit comments

Comments
 (0)