Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,6 @@ ENV/

# Vagrant
.vagrant/

# Pngs
*.png
7 changes: 4 additions & 3 deletions train/dataprepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ class Lang:

"""

def __init__(self, name):
def __init__(self, name, threshold=1):
    """Initialize a Lang vocabulary with the reserved special tokens.

    Args:
        name: Human-readable name of this language/vocabulary.
        threshold: Minimum occurrence count for a word to be kept when
            the vocabulary is later filtered (default 1 keeps everything).
    """
    # Ken added <EOB> on 04/04/2018
    self.name = name
    # Reserved special tokens occupy indices 0-5.
    self.word2index = {"<SOS>": 0, "<EOS>": 1, "<PAD>": 2, "<UNK>": 3, "<EOB>": 4, "<BLK>": 5}
    self.word2count = {"<EOS>": 0, "<PAD>": 0, "<EOB>": 0, "<BLK>": 0}
    self.index2word = {0: "<SOS>", 1: "<EOS>", 2: "<PAD>", 3: "<UNK>", 4: "<EOB>", 5: "<BLK>"}
    self.threshold = threshold
    # BUG FIX: there are 6 reserved tokens (indices 0-5), so the next
    # fresh index handed out by addword must be 6. The previous value
    # (5) made the first added word reuse index 5 and silently clobber
    # "<BLK>" in index2word.
    self.n_words = 6  # count of reserved special tokens

def addword(self, word):
Expand All @@ -45,7 +46,7 @@ def addword(self, word):
self.n_words += 1
else:
self.word2count[word] += 1


def readLang(data_set):
"""The function to wrap up a data_set.
Expand Down Expand Up @@ -157,7 +158,7 @@ def findword2index(lang, word):
# data_set[i].append([idx_triplets] + [idx_summary])
data_set[i].idx_data = [idx_triplets] + [idx_summary]
data_set[i].sent_leng = sentence_cnt

return data_set


Expand Down
20 changes: 10 additions & 10 deletions train/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
use_cuda = torch.cuda.is_available()

MAX_LENGTH = 800
LAYER_DEPTH = 2
MAX_SENTENCES = 5
MAX_TRAIN_NUM = 200
LAYER_DEPTH = 4
MAX_SENTENCES = None
MAX_TRAIN_NUM = None

Model_name = None
#Model_name = 'pretrain_ms8'
iterNum = 500
Model_name = 'pl'
iterNum = 8495
USE_MODEL = None
if Model_name is not None:
USE_MODEL = ['./models/'+Model_name + '_' + s + '_' + str(iterNum) for s in ['encoder', 'decoder', 'optim']]
Expand All @@ -25,23 +25,23 @@
# LR = 0.003 # Adam
ITER_TIME = 220
BATCH_SIZE = 2
GRAD_CLIP = 5
GRAD_CLIP = 2

# Parameter for display
GET_LOSS = 1
SAVE_MODEL = 5
GET_LOSS = 10
SAVE_MODEL = 1

# Choose models

# ENCODER_STYLE = 'LIN'
ENCODER_STYLE = 'BiLSTM'
ENCODER_STYLE = 'BiLSTMMax'
#ENCODER_STYLE = 'RNN'
DECODER_STYLE = 'RNN'

# ENCODER_STYLE = 'HierarchicalBiLSTM'
#ENCODER_STYLE = 'HierarchicalRNN'
#DECODER_STYLE = 'HierarchicalRNN'
OUTPUT_FILE = 'pretrain_copy_ms5'
OUTPUT_FILE = 'copy'
COPY_PLAYER = True
TOCOPY = True

Expand Down
Loading