Skip to content

Commit 520018c

Browse files
authored
Add RNN Classifier (#303)
* Add rnn classifier * Minor fix
1 parent b29cc54 commit 520018c

File tree

5 files changed

+468
-0
lines changed

5 files changed

+468
-0
lines changed

docs/code/modules.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ Classifiers
239239
.. autoclass:: texar.torch.modules.GPT2Classifier
240240
:members:
241241

242+
:hidden:`UnidirectionalRNNClassifier`
243+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
244+
.. autoclass:: texar.torch.modules.UnidirectionalRNNClassifier
245+
:members:
246+
242247
:hidden:`Conv1DClassifier`
243248
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
244249
.. autoclass:: texar.torch.modules.Conv1DClassifier
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Copyright 2019 The Texar Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Unit tests for RNN classifiers.
16+
"""
17+
18+
import unittest
19+
20+
import torch
21+
22+
from texar.torch.modules.classifiers.rnn_classifiers import *
23+
24+
25+
class UnidirectionalRNNClassifierTest(unittest.TestCase):
26+
r"""Tests :class:`~texar.torch.modules.UnidirectionalRNNClassifier` class.
27+
"""
28+
29+
def setUp(self) -> None:
30+
self.batch_size = 2
31+
self.max_length = 3
32+
self.emb_dim = 4
33+
self.inputs = torch.rand(
34+
self.batch_size, self.max_length, self.emb_dim)
35+
36+
def test_trainable_variables(self):
37+
r"""Tests the functionality of automatically collecting trainable
38+
variables.
39+
"""
40+
# case 1
41+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim)
42+
output, _ = classifier(self.inputs)
43+
self.assertEqual(len(classifier.trainable_variables), 4 + 2)
44+
self.assertEqual(output.size()[-1], classifier.output_size)
45+
46+
# case 2
47+
hparams = {
48+
"output_layer": {"num_layers": 2},
49+
"logit_layer_kwargs": {"bias": False}
50+
}
51+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim,
52+
hparams=hparams)
53+
output, _ = classifier(self.inputs)
54+
self.assertEqual(len(classifier.trainable_variables), 4 + 2 + 2 + 1)
55+
56+
def test_encode(self):
57+
r"""Tests encoding.
58+
"""
59+
# case 1
60+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim)
61+
logits, pred = classifier(self.inputs)
62+
self.assertEqual(logits.shape,
63+
torch.Size([self.batch_size,
64+
classifier.hparams.num_classes]))
65+
self.assertEqual(pred.shape, torch.Size([self.batch_size]))
66+
67+
# case 2
68+
hparams = {
69+
"num_classes": 10,
70+
"clas_strategy": "time_wise"
71+
}
72+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim,
73+
hparams=hparams)
74+
logits, pred = classifier(self.inputs)
75+
self.assertEqual(logits.shape,
76+
torch.Size([self.batch_size, self.max_length,
77+
classifier.hparams.num_classes]))
78+
self.assertEqual(pred.shape,
79+
torch.Size([self.batch_size, self.max_length]))
80+
81+
# case 3
82+
hparams = {
83+
"output_layer": {
84+
"num_layers": 1,
85+
"layer_size": 10
86+
},
87+
"num_classes": 0,
88+
"clas_strategy": "time_wise"
89+
}
90+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim,
91+
hparams=hparams)
92+
logits, pred = classifier(self.inputs)
93+
self.assertEqual(logits.shape,
94+
torch.Size([self.batch_size, self.max_length, 10]))
95+
self.assertEqual(pred.shape,
96+
torch.Size([self.batch_size, self.max_length]))
97+
98+
# case 4
99+
hparams = {
100+
"num_classes": 10,
101+
"clas_strategy": "all_time",
102+
"max_seq_length": 5
103+
}
104+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim,
105+
hparams=hparams)
106+
logits, pred = classifier(self.inputs)
107+
self.assertEqual(logits.shape,
108+
torch.Size([self.batch_size,
109+
classifier.hparams.num_classes]))
110+
self.assertEqual(pred.shape, torch.Size([self.batch_size]))
111+
112+
def test_binary(self):
113+
r"""Tests binary classification.
114+
"""
115+
# case 1
116+
hparams = {
117+
"num_classes": 1,
118+
"clas_strategy": "time_wise"
119+
}
120+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim,
121+
hparams=hparams)
122+
logits, pred = classifier(self.inputs)
123+
self.assertEqual(logits.shape,
124+
torch.Size([self.batch_size, self.max_length]))
125+
self.assertEqual(pred.shape,
126+
torch.Size([self.batch_size, self.max_length]))
127+
128+
# case 2
129+
hparams = {
130+
"output_layer": {
131+
"num_layers": 1,
132+
"layer_size": 10
133+
},
134+
"num_classes": 1,
135+
"clas_strategy": "time_wise"
136+
}
137+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim,
138+
hparams=hparams)
139+
logits, pred = classifier(self.inputs)
140+
self.assertEqual(logits.shape,
141+
torch.Size([self.batch_size, self.max_length]))
142+
self.assertEqual(pred.shape,
143+
torch.Size([self.batch_size, self.max_length]))
144+
145+
# case 3
146+
hparams = {
147+
"num_classes": 1,
148+
"clas_strategy": "all_time",
149+
"max_seq_length": 5
150+
}
151+
classifier = UnidirectionalRNNClassifier(input_size=self.emb_dim,
152+
hparams=hparams)
153+
logits, pred = classifier(self.inputs)
154+
self.assertEqual(logits.shape,
155+
torch.Size([self.batch_size]))
156+
self.assertEqual(pred.shape,
157+
torch.Size([self.batch_size]))
158+
159+
160+
if __name__ == "__main__":
161+
unittest.main()

texar/torch/modules/classifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@
1919
from texar.torch.modules.classifiers.classifier_base import *
2020
from texar.torch.modules.classifiers.conv_classifiers import *
2121
from texar.torch.modules.classifiers.gpt2_classifier import *
22+
from texar.torch.modules.classifiers.rnn_classifiers import *
2223
from texar.torch.modules.classifiers.roberta_classifier import *
2324
from texar.torch.modules.classifiers.xlnet_classifier import *

0 commit comments

Comments
 (0)