-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
32 lines (26 loc) · 798 Bytes
/
main.py
File metadata and controls
32 lines (26 loc) · 798 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gpt_2_simple as gpt2
import os
# Fine-tune a pretrained GPT-2 checkpoint on one artist's lyrics, then sample
# text from the fine-tuned model in the same TensorFlow session.

model_name = "355M"

# Fetch the pretrained base weights into models/<model_name>/ only when they
# are not already present locally.
if not os.path.isdir(os.path.join("models", model_name)):
    print(f"Downloading {model_name} model...")
    gpt2.download_gpt2(model_name=model_name)

# The artist name selects the training corpus and also names the fine-tuning
# run, so each artist's checkpoints are kept under their own run name.
name_of_artist = 'travis'
file_name = f"training_data/{name_of_artist}.txt"

sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              dataset=file_name,
              # Reuse the variable so the fine-tuned model is always the one
              # downloaded/checked above, not a second hard-coded copy.
              model_name=model_name,
              steps=1000,
              restore_from='fresh',  # start from the base weights, not a prior checkpoint
              run_name=name_of_artist,
              print_every=1,
              sample_every=200,
              save_every=200
              )

# gpt_2_simple's generate() otherwise uses run_name='run1' and loads that
# run's hparams.json; point it at the run we just fine-tuned.
gpt2.generate(sess,
              run_name=name_of_artist,
              length=500,
              temperature=0.8,
              prefix="It's lit",
              nsamples=25,  # gpt_2_simple requires nsamples % batch_size == 0
              batch_size=5
              )