1414# limitations under the License.
1515# ******************************************************************************
1616import tensorflow as tf
17- from tensorflow import convert_to_tensor , keras
1817
1918
20- class CRF (keras .layers .Layer ):
19+ class CRF (tf . keras .layers .Layer ):
2120 """
2221 Conditional Random Field layer (tf.keras)
2322 `CRF` can be used as the last layer in a network (as a classifier). Input shape (features)
@@ -29,55 +28,36 @@ class CRF(keras.layers.Layer):
2928
3029 Args:
3130 num_labels (int): the number of labels to tag each temporal input.
32- mode (string, optional): operation mode, 'reg' for regular full sequence learning (all
33- sequences have equal length), or 'pad' for using with supplied sequence lengths (useful
34- for padded sequences)
3531
3632 Input shape:
37- 'reg' mode - nD tensor with shape `(batch_size, sentence length, num_classes)`.
38- 'pad' mode - tuple of `(batch_size, sentence length, num_classes)`, `(batch_size, 1)`
33+ nD tensor with shape `(batch_size, sentence length, num_classes)`.
3934
4035 Output shape:
4136 nD tensor with shape: `(batch_size, sentence length, num_classes)`.
4237 """
43- def __init__ (self , num_classes , mode = 'reg' , ** kwargs ):
38+
39+ def __init__ (self , num_classes , ** kwargs ):
4440 self .transitions = None
4541 super (CRF , self ).__init__ (** kwargs )
4642 # num of output labels
4743 self .output_dim = int (num_classes )
48- self .mode = mode
49- if self .mode == 'pad' :
50- self .input_spec = [keras .layers .InputSpec (min_ndim = 3 ),
51- keras .layers .InputSpec (min_ndim = 2 )]
52- elif self .mode == 'reg' :
53- self .input_spec = keras .layers .InputSpec (min_ndim = 3 )
54- else :
55- raise ValueError
56- self .supports_masking = True
44+ self .input_spec = tf .keras .layers .InputSpec (min_ndim = 3 )
45+ self .supports_masking = False
5746 self .sequence_lengths = None
5847
5948 def get_config (self ):
6049 config = {
6150 'output_dim' : self .output_dim ,
62- 'mode' : self .mode ,
6351 'supports_masking' : self .supports_masking ,
6452 'transitions' : tf .keras .backend .eval (self .transitions )
6553 }
6654 base_config = super (CRF , self ).get_config ()
6755 return dict (list (base_config .items ()) + list (config .items ()))
6856
6957 def build (self , input_shape ):
70- if self .mode == 'pad' :
71- assert len (input_shape ) == 2
72- assert len (input_shape [0 ]) == 3
73- assert len (input_shape [1 ]) == 2
74- f_shape = tf .TensorShape (input_shape [0 ])
75- input_spec = [keras .layers .InputSpec (min_ndim = 3 , axes = {- 1 : f_shape [- 1 ]}),
76- keras .layers .InputSpec (min_ndim = 2 , axes = {- 1 : 1 }, dtype = tf .int32 )]
77- else :
78- assert len (input_shape ) == 3
79- f_shape = tf .TensorShape (input_shape )
80- input_spec = keras .layers .InputSpec (min_ndim = 3 , axes = {- 1 : f_shape [- 1 ]})
58+ assert len (input_shape ) == 3
59+ f_shape = tf .TensorShape (input_shape )
60+ input_spec = tf .keras .layers .InputSpec (min_ndim = 3 , axes = {- 1 : f_shape [- 1 ]})
8161
8262 if f_shape [- 1 ] is None :
8363 raise ValueError ('The last dimension of the inputs to `CRF` '
@@ -92,21 +72,26 @@ def build(self, input_shape):
9272 trainable = True )
9373 self .built = True
9474
95- def call (self , inputs , ** kwargs ):
96- if self .mode == 'pad' :
97- sequences = convert_to_tensor (inputs [0 ], dtype = self .dtype )
98- self .sequence_lengths = tf .keras .backend .flatten (inputs [- 1 ])
75+ # pylint: disable=arguments-differ
76+ def call (self , inputs , sequence_lengths = None , ** kwargs ):
77+ sequences = tf .convert_to_tensor (inputs , dtype = self .dtype )
78+ if sequence_lengths is not None :
79+ assert len (sequence_lengths .shape ) == 2
80+ assert tf .convert_to_tensor (sequence_lengths ).dtype == 'int32'
81+ seq_len_shape = tf .convert_to_tensor (sequence_lengths ).get_shape ().as_list ()
82+ assert seq_len_shape [1 ] == 1
83+ self .sequence_lengths = tf .keras .backend .flatten (sequence_lengths )
9984 else :
100- sequences = convert_to_tensor ( inputs , dtype = self . dtype )
101- shape = tf .shape (inputs )
102- self . sequence_lengths = tf . ones ( shape [ 0 ], dtype = tf . int32 ) * ( shape [ 1 ])
85+ self . sequence_lengths = tf . ones ( tf . shape ( inputs )[ 0 ] , dtype = tf . int32 ) * \
86+ ( tf .shape (inputs )[ 1 ] )
87+
10388 viterbi_sequence , _ = tf .contrib .crf .crf_decode (sequences , self .transitions ,
10489 self .sequence_lengths )
105- output = keras .backend .one_hot (viterbi_sequence , self .output_dim )
106- return keras .backend .in_train_phase (sequences , output )
90+ output = tf . keras .backend .one_hot (viterbi_sequence , self .output_dim )
91+ return tf . keras .backend .in_train_phase (sequences , output )
10792
10893 def loss (self , y_true , y_pred ):
109- y_pred = convert_to_tensor (y_pred , dtype = self .dtype )
94+ y_pred = tf . convert_to_tensor (y_pred , dtype = self .dtype )
11095 log_likelihood , self .transitions = \
11196 tf .contrib .crf .crf_log_likelihood (y_pred ,
11297 tf .cast (tf .keras .backend .argmax (y_true ),
@@ -116,12 +101,8 @@ def loss(self, y_true, y_pred):
116101 return tf .reduce_mean (- log_likelihood )
117102
118103 def compute_output_shape (self , input_shape ):
119- if self .mode == 'pad' :
120- data_shape = input_shape [0 ]
121- else :
122- data_shape = input_shape
123- tf .TensorShape (data_shape ).assert_has_rank (3 )
124- return data_shape [:2 ] + (self .output_dim ,)
104+ tf .TensorShape (input_shape ).assert_has_rank (3 )
105+ return input_shape [:2 ] + (self .output_dim ,)
125106
126107 @property
127108 def viterbi_accuracy (self ):
@@ -130,7 +111,7 @@ def accuracy(y_true, y_pred):
130111 sequence_lengths = tf .ones (shape [0 ], dtype = tf .int32 ) * (shape [1 ])
131112 viterbi_sequence , _ = tf .contrib .crf .crf_decode (y_pred , self .transitions ,
132113 sequence_lengths )
133- output = keras .backend .one_hot (viterbi_sequence , self .output_dim )
114+ output = tf . keras .backend .one_hot (viterbi_sequence , self .output_dim )
134115 return tf .keras .metrics .categorical_accuracy (y_true , output )
135116 accuracy .func_name = 'viterbi_accuracy'
136117 return accuracy
0 commit comments