-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgen-user.py
More file actions
38 lines (29 loc) · 913 Bytes
/
gen-user.py
File metadata and controls
38 lines (29 loc) · 913 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from generator import RNNLayerGenerator
from train import TrainerFactory
def train():
    """Run the "layer" trainer over ``./data`` with this script's fixed settings.

    Asks ``TrainerFactory.get_trainer`` for a trainer configured with
    300 epochs, batch size 128, learning rate 0.0001 on the CPU, passing
    ``logfile="train_loss.log"`` and ``verbose=1``, then calls the trainer's
    ``run_train_loop()``. Returns nothing.
    """
    settings = dict(
        trainer_type="layer",
        root_dir='./data',
        epochs=300,
        batch_size=128,
        lr=0.0001,
        device="cpu",
        logfile="train_loss.log",
        verbose=1,
    )
    TrainerFactory.get_trainer(**settings).run_train_loop()
def generate(number=5, race='', gender='',
             model_path='./models/rnn_layer_epoch_250.pt'):
    """Load a saved RNN layer generator and print the tuples it generates.

    All parameters default to the values this function previously
    hard-coded, so ``generate()`` behaves as before.

    Args:
        number: how many results to request from the generator (default 5).
        race: race value forwarded to ``RNNLayerGenerator.generate``
            (default ``''``).
        gender: gender value forwarded to ``RNNLayerGenerator.generate``
            (default ``''``).
        model_path: checkpoint file passed to ``RNNLayerGenerator``.
            NOTE(review): the default is the epoch-250 checkpoint while
            ``train()`` runs 300 epochs; confirm this is the intended model.

    Returns:
        None. One line per generated tuple is printed to stdout.
    """
    generator = RNNLayerGenerator(model_path=model_path)
    for name_tuple in generator.generate(number, race, gender):
        # Printed as "<field 0>: <field 2> <field 1>". The meaning of each
        # field is defined by RNNLayerGenerator.generate; presumably
        # name / race / gender. TODO confirm.
        print(f"{name_tuple[0]}: {name_tuple[2]} {name_tuple[1]}")
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-train")
args = parser.parse_args()
if args.train:
train()
generate()