Commit 8097309

Bugfix in VAE example (#253)
* Bugfix in vae_text example
* Bugfix in generation part of vae_text example
* Fix pylint error
1 parent 53d8ead commit 8097309

11 files changed: 66 additions, 12 deletions

examples/vae_text/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 /data/
 /models/
 /simple-examples.tgz
+/yahoo.zip

examples/vae_text/README.md

Lines changed: 5 additions & 3 deletions
@@ -2,12 +2,13 @@
 
 This example builds a VAE for text generation, with an LSTM as encoder and an LSTM or [Transformer](https://arxiv.org/pdf/1706.03762.pdf) as decoder. Training is performed on the official PTB data and Yahoo data, respectively.
 
-The VAE with LSTM decoder is first decribed in [(Bowman et al., 2015) Generating Sentences from a Continuous Space](https://arxiv.org/pdf/1511.06349.pdf)
+The VAE with LSTM decoder is first described in [(Bowman et al., 2015) Generating Sentences from a Continuous Space](https://arxiv.org/pdf/1511.06349.pdf)
 
-The Yahoo dataset is from [(Yang et al., 2017) Improved Variational Autoencoders for Text Modeling using Dilated Convolutions](https://arxiv.org/pdf/1702.08139.pdf), which is created by sampling 100k documents from the original Yahoo Answer data. The average document length is 78 and the vocab size is 200k.
+The Yahoo dataset is from [(Yang et al., 2017) Improved Variational Autoencoders for Text Modeling using Dilated Convolutions](https://arxiv.org/pdf/1702.08139.pdf), which is created by sampling 100k documents from the original Yahoo Answer data. The average document length is 78 and the vocabulary size is 200k.
 
 ## Data
 The datasets can be downloaded by running:
+
 ```shell
 python prepare_data.py --data ptb
 python prepare_data.py --data yahoo
@@ -30,11 +31,12 @@ Here:
 
 ## Generation
 Generating sentences with pre-trained model can be performed with the following command:
+
 ```shell
 python vae_train.py --config config_file --mode predict --model /path/to/model.ckpt --out /path/to/output
 ```
 
-Here `--model` specifies the saved model checkpoint, which is saved in `./models/dataset_name/` at training time. For example, the model path is `./models/ptb/ptb_lstmDecoder.ckpt` when generating with a LSTM decoder trained on PTB dataset. Generated sentences will be written to standard output if `--out` is not specifcied.
+Here `--model` specifies the saved model checkpoint, which is saved in `./models/dataset_name/` at training time. For example, the model path is `./models/ptb/ptb_lstmDecoder.ckpt` when generating with a LSTM decoder trained on PTB dataset. Generated sentences will be written to standard output if `--out` is not specified.
 
 ## Results

examples/vae_text/config_lstm_ptb.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """VAE config.
 """

examples/vae_text/config_lstm_yahoo.py

Lines changed: 1 addition & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright 2018 The Texar Authors. All Rights Reserved.
+# Copyright 2019 The Texar Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """VAE config.
 """

examples/vae_text/config_trans_ptb.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Config file of VAE with Trasnformer decoder, on PTB data.
+"""Config file of VAE with Transformer decoder, on PTB data.
 """
 
 dataset = 'ptb'

examples/vae_text/config_trans_yahoo.py

Lines changed: 1 addition & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright 2018 The Texar Authors. All Rights Reserved.
+# Copyright 2019 The Texar Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """VAE config.
 """

examples/vae_text/vae_train.py

Lines changed: 5 additions & 3 deletions
@@ -68,8 +68,7 @@ def kl_divergence(means: Tensor, logvars: Tensor) -> Tensor:
 class VAE(nn.Module):
     _latent_z: Tensor
 
-    def __init__(self,
-                 vocab_size: int, config_model):
+    def __init__(self, vocab_size: int, config_model):
         super().__init__()
         # Model architecture
         self._config = config_model
@@ -88,7 +87,7 @@ def __init__(self,
         if config_model.decoder_type == "lstm":
             self.lstm_decoder = tx.modules.BasicRNNDecoder(
                 input_size=(self.decoder_w_embedder.dim +
-                            config_model.batch_size),
+                            config_model.latent_dims),
                 vocab_size=vocab_size,
                 token_embedder=self._embed_fn_rnn,
                 hparams={"rnn_cell": config_model.dec_cell_hparams})
@@ -362,6 +361,9 @@ def _generate(start_tokens: torch.LongTensor,
             latent_z=latent_z,
             max_decoding_length=100)
 
+        if config.decoder_type == "transformer":
+            outputs = outputs[0]
+
         sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())
 
         if filename is None:
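
The two functional fixes in `vae_train.py` can be read as shape and type bookkeeping. The sketch below uses only the standard library and is not Texar code. It assumes (the diff does not show this) that the decoder's token embedder appends the latent vector to each token embedding, which would explain why `input_size` must be the embedding width plus `latent_dims` rather than `batch_size`. It also assumes the transformer decode path returns a tuple whose first element carries `sample_id`, which is what `outputs = outputs[0]` suggests. `concat_latent`, `DecoderOutput`, `unwrap_outputs`, and all sizes are hypothetical names and values chosen for illustration.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple, Union


def concat_latent(token_embeddings: Sequence[Sequence[float]],
                  latent_z: Sequence[float]) -> List[List[float]]:
    """Append the same latent vector to every token embedding of a sentence."""
    return [list(emb) + list(latent_z) for emb in token_embeddings]


@dataclass
class DecoderOutput:
    """Hypothetical stand-in for a decoder output that carries sampled ids."""
    sample_id: List[int]


def unwrap_outputs(outputs: Union[DecoderOutput, Tuple[DecoderOutput, int]],
                   decoder_type: str) -> DecoderOutput:
    """Mirror the commit's check: keep element 0 for the transformer decoder."""
    if decoder_type == "transformer":
        return outputs[0]
    return outputs


# Hypothetical sizes, for illustration only.
embed_dim, latent_dims, batch_size = 4, 3, 32

sentence = [[0.0] * embed_dim for _ in range(5)]  # five token embeddings
z = [0.1] * latent_dims                           # one latent vector

decoder_inputs = concat_latent(sentence, z)
input_size = len(decoder_inputs[0])               # width seen by the RNN cell
```

Under these assumptions the per-step input width depends only on the embedding and latent sizes, so tying it to `batch_size` would be correct only when the two values happen to coincide.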

texar/torch/custom/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Custom modules in Texar
 """

texar/torch/custom/activation.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Custom activation functions used in various methods.
 """

texar/torch/custom/distributions.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from torch.distributions import Normal, Independent
 
