Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 52e3de1

Browse files
authored
fix sig (#1368)
1 parent 7f29267 commit 52e3de1

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

scripts/machine_translation/gnmt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def __call__(self, inputs, states=None, valid_length=None):
133133
"""
134134
return super(GNMTEncoder, self).__call__(inputs, states, valid_length)
135135

136-
def forward(self, inputs, states=None, valid_length=None): #pylint: disable=arguments-differ, missing-docstring
136+
def forward(self, inputs, states=None, valid_length=None): #pylint: missing-docstring
137137
# TODO(sxjscience) Accelerate the forward using HybridBlock
138138
_, length, _ = inputs.shape
139139
new_states = []

src/gluonnlp/model/bert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def __init__(self, *, num_layers=2, units=512, hidden_size=2048,
340340
activation=activation, layer_norm_eps=layer_norm_eps)
341341
self.transformer_cells.add(cell)
342342

343-
def __call__(self, inputs, states=None, valid_length=None): # pylint: disable=arguments-differ
343+
def __call__(self, inputs, states=None, valid_length=None):
344344
"""Encode the inputs given the states and valid sequence length.
345345
346346
Parameters

src/gluonnlp/model/seq2seq_encoder_decoder.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,27 @@ class Seq2SeqEncoder(Block):
5353
r"""Base class of the encoders in sequence to sequence learning models.
5454
"""
5555

56-
def __call__(self, inputs, valid_length=None, states=None): #pylint: disable=arguments-differ
56+
def __call__(self, inputs, states=None, valid_length=None): #pylint: disable=arguments-differ
5757
"""Encode the input sequence.
5858
5959
Parameters
6060
----------
6161
inputs : NDArray
6262
The input sequence, Shape (batch_size, length, C_in).
63+
states : list of NDArrays or None, default None
64+
List that contains the initial states of the encoder.
6365
valid_length : NDArray or None, default None
6466
The valid length of the input sequence, Shape (batch_size,). This is used when the
6567
input sequences are padded. If set to None, all elements in the sequence are used.
66-
states : list of NDArrays or None, default None
67-
List that contains the initial states of the encoder.
6868
6969
Returns
7070
-------
7171
outputs : list
7272
Outputs of the encoder.
7373
"""
74-
return super(Seq2SeqEncoder, self).__call__(inputs, valid_length, states)
74+
return super(Seq2SeqEncoder, self).__call__(inputs, states, valid_length)
7575

76-
def forward(self, inputs, valid_length=None, states=None): #pylint: disable=arguments-differ
76+
def forward(self, inputs, states=None, valid_length=None): #pylint: disable=arguments-differ
7777
raise NotImplementedError
7878

7979

src/gluonnlp/model/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def __init__(self, *, attention_cell='multi_head', num_layers=2, units=512, hidd
344344
scaled=scaled, output_attention=output_attention, prefix='transformer%d_' % i)
345345
self.transformer_cells.add(cell)
346346

347-
def __call__(self, inputs, states=None, valid_length=None): #pylint: disable=arguments-differ
347+
def __call__(self, inputs, states=None, valid_length=None):
348348
"""Encode the inputs given the states and valid sequence length.
349349
350350
Parameters

0 commit comments

Comments
 (0)