-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathgenerate.jl
More file actions
151 lines (116 loc) · 3.5 KB
/
generate.jl
File metadata and controls
151 lines (116 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
##########################################################################################
# Author: Jared L. Ostmeyer
# Date Started: 2015-12-23
# Environment: Julia v0.4
# Purpose: Sample restricted Boltzmann machine as a generative model.
##########################################################################################
##########################################################################################
# Packages
##########################################################################################
# Load tools for sampling N-way outcomes according to a set of weights.
#
using StatsBase
# Load tools for saving images.
#
using Images
##########################################################################################
# Settings
##########################################################################################
# Load neural network parameters (produced by training; read from bin/).
#
b_x = readcsv("bin/train_b_x.csv")
W_xh = readcsv("bin/train_W_xh.csv")
b_z = readcsv("bin/train_b_z.csv")
W_zh = readcsv("bin/train_W_zh.csv")
b_h = readcsv("bin/train_b_h.csv")
# Sampling schedule: number of Gibbs passes run on each chain, and number of chains
# (one generated sample per chain).
#
N_equilibrate = 10000
N_samples = 100
# Image dimensions.
#
height = 28
width = 28
# Number of neurons in each layer. These are read off the loaded weights rather than
# hard-coded, so they always agree with the trained network. The sampler multiplies
# W_xh by the hidden state and W_xh' by the visible state (N_x-by-N_h), and does the same
# with W_zh for the label layer (N_z-by-N_h).
#
N_x = size(W_xh, 1)
N_z = size(W_zh, 1)
N_h = size(W_xh, 2)
# Fail fast, before the expensive sampling loop, if the parameters are inconsistent with
# each other or with the image size used to render the samples.
#
size(W_zh, 2) == N_h || throw(DimensionMismatch(
    "W_zh has $(size(W_zh, 2)) hidden units but W_xh has $N_h"))
height*width == N_x || throw(DimensionMismatch(
    "images are $(height)x$(width) = $(height*width) pixels but W_xh has $N_x visible units"))
# Randomly initialize each layer (one column per chain).
#
px_s = rand(N_x, N_samples)
pz_s = rand(N_z, N_samples)
ph_s = rand(N_h, N_samples)
##########################################################################################
# Methods
##########################################################################################
# Activation functions.
#
# Logistic sigmoid 1/(1 + e^(-x)). For an array argument it is applied element-wise, which
# relies on Julia v0.4's vectorized `exp` and array-scalar `+`.
#
function sigmoid(x)
    decay = exp(-x)
    return 1.0./(decay+1.0)
end
# Softmax over all entries of x: exp(x) normalized to sum to one. The result has the same
# shape as x.
#
# The maximum entry is subtracted before exponentiating. This leaves the result
# mathematically unchanged, but it stops `exp` from overflowing to Inf (and Inf/Inf = NaN)
# when some entry of x is large (above about 709 for Float64). `exp` is also now evaluated
# only once per entry. Empty input yields an empty result, as before.
#
function softmax(x)
    m = isempty(x) ? zero(eltype(x)) : maximum(x)
    e = map(t -> exp(t-m), x)
    return e./sum(e)
end
# Sampling methods.
#
# Draw an independent Bernoulli sample for every entry of p. The result is an array of
# p's shape holding 1.0 with probability p and 0.0 otherwise.
# rand() is uniform on [0, 1), so the strict `<` makes P(1.0) exactly p. It also
# guarantees that p = 0 never fires, whereas `<=` would fire whenever rand() returned
# exactly 0.0.
#
state(p) = 1.0*(rand(size(p)) .< p)
# Draw one categorical (one-of-N) sample per column of p, using the column's entries as
# sampling weights via StatsBase's `sample`/`WeightVec`. The result is a 0.0/1.0 array of
# p's size holding a single 1.0 per column, at the drawn row.
#
function choose(p)
    onehot = zeros(size(p))
    for col = 1:size(p, 2)
        row = sample(WeightVec(p[:,col]))
        onehot[row, col] = 1.0
    end
    return onehot
end
##########################################################################################
# Generate
##########################################################################################
# Run `N_equilibrate` passes of Gibbs sampling on every chain (column) and store each
# layer's activation probabilities in place.
#
# px_s, pz_s, ph_s: probabilities for the visible (x), label (z) and hidden (h) layers, with
#   one column per chain. On entry only ph_s is read, because it seeds the first hidden
#   state. On exit (when N_equilibrate >= 1) all three hold the probabilities from the
#   last pass.
# W_xh, b_x, W_zh, b_z, b_h: trained weights and biases.
# Returns the three (mutated) probability arrays.
#
function gibbs_sample!(px_s, pz_s, ph_s, W_xh, b_x, W_zh, b_z, b_h, N_equilibrate)
    for i = 1:size(px_s, 2)
        # Only the hidden layer's initial state is ever read. Each pass resamples x and z
        # from h before using them, so drawing initial x and z states would be dead work.
        h = state(ph_s[:,i])
        for j = 1:N_equilibrate
            px_s[:,i] = sigmoid(W_xh*h+b_x)
            x = state(px_s[:,i])
            pz_s[:,i] = softmax(W_zh*h+b_z)
            z = choose(pz_s[:,i])
            ph_s[:,i] = sigmoid(W_xh'*x+W_zh'*z+b_h)
            h = state(ph_s[:,i])
        end
    end
    return px_s, pz_s, ph_s
end
# Repeatedly sample the model. The loop runs inside a function rather than at top level on
# untyped globals, so Julia can compile it for the concrete array types (function barrier).
#
gibbs_sample!(px_s, pz_s, ph_s, W_xh, b_x, W_zh, b_z, b_h, N_equilibrate)
##########################################################################################
# Figure
##########################################################################################
# Samples will be rendered as tiles in a grid that is as close to square as possible, filled
# row by row (left to right, then top to bottom).
#
N_columns = ceil(Int, sqrt(N_samples))
N_rows = ceil(Int, N_samples/N_columns)
# Pixels on which the tiles will be rendered. Grid cells without a sample stay at 0.0.
#
pixels = zeros(height*N_rows, width*N_columns)
# Copy each sample into its tile. The loop runs over the samples themselves, not over grid
# cells with a separate counter and a `break` that only leaves the inner loop, so it can
# never read past the last sample whatever the grid dimensions are.
#
for k = 1:N_samples
    # Grid row and column (1-based) of the tile for sample k.
    tile_row = div(k-1, N_columns)+1
    tile_col = rem(k-1, N_columns)+1
    # Pixel ranges covered by that tile.
    pix_rows = (tile_row-1)*height+1:tile_row*height
    pix_cols = (tile_col-1)*width+1:tile_col*width
    # NOTE(review): reshape fills column-major, so this assumes each column of px_s stores
    # its image column by column. Confirm against how the training images were flattened;
    # if the samples look transposed, the flattening was row-major.
    pixels[pix_rows, pix_cols] = reshape(px_s[:,k], height, width)
end
# Create grayscale image with black background (unfilled pixels are 0.0).
#
figure = convert(Image, pixels)
# Save the tiles as an image file.
#
save("generate.png", figure)