-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathstochastic_variables.py
More file actions
83 lines (61 loc) · 3.31 KB
/
stochastic_variables.py
File metadata and controls
83 lines (61 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import tensorflow as tf
import math
from tensorflow.contrib.rnn import BasicLSTMCell, LSTMStateTuple
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _checked_scope
def gaussian_mixture_nll(samples, mixing_weights, mean1, mean2, std1, std2):
    """
    Computes the elementwise negative log likelihood of ``samples`` under a
    mixture of two Gaussian distributions.

    Parameters
    ----------
    samples : Tensor of values to score under the mixture.
    mixing_weights : indexable pair of component weights (expected to sum to 1).
    mean1, mean2 : means of the two Gaussian components.
    std1, std2 : standard deviations (not variances) of the two components.

    Returns
    -------
    Tensor of the same shape as ``samples`` holding -log p(sample).
    """
    # N(x | mu, sigma) = 1 / (sigma * sqrt(2*pi)) * exp(-(x - mu)^2 / (2 * sigma^2)).
    # BUG FIX: the previous version used std1/std2 where variances were required
    # (normalizer 1/sqrt(2*pi*std) and exponent /(2*std)), which contradicted the
    # parameter names and docstring. They are now treated as standard deviations.
    normaliser = math.sqrt(2.0 * math.pi)
    gaussian1 = (1.0 / (std1 * normaliser)) * tf.exp(
        - tf.square(samples - mean1) / (2.0 * tf.square(std1)))
    gaussian2 = (1.0 / (std2 * normaliser)) * tf.exp(
        - tf.square(samples - mean2) / (2.0 * tf.square(std2)))
    mixture = (mixing_weights[0] * gaussian1) + (mixing_weights[1] * gaussian2)
    return - tf.log(mixture)
def get_random_normal_variable(name, mean, standard_dev, shape, dtype):
    """
    Create a stochastic "variable": a sample drawn from a Normal distribution
    whose mean and standard deviation are themselves trainable TF variables.

    Returns a tuple of (sampled weights, mean variable, positive std tensor).
    """
    # Initialise the raw std parameter at softplus^{-1}(standard_dev), so that
    # after applying softplus below the effective standard deviation starts at
    # exactly the user-specified value while remaining positive during training.
    inverse_softplus_init = tf.log(tf.exp(standard_dev) - 1.0) * tf.ones(shape)

    mean_variable = tf.get_variable(name + "_mean", shape,
                                    initializer=tf.constant_initializer(mean),
                                    dtype=dtype)
    raw_std = tf.get_variable(name + "_standard_deviation",
                              initializer=inverse_softplus_init,
                              dtype=dtype)
    positive_std = tf.nn.softplus(raw_std)

    # Reparameterisation trick: sample = mean + std * epsilon, epsilon ~ N(0, 1).
    noise = tf.random_normal(shape, 0.0, 1.0, dtype)
    sample = mean_variable + (positive_std * noise)
    return sample, mean_variable, positive_std
class ExternallyParameterisedLSTM(BasicLSTMCell):
    """
    A simple extension of an LSTM in which the weights are passed in to the class,
    rather than being automatically generated inside the cell when it is called.
    This allows us to parameterise them in other, funky ways.
    """
    def __init__(self, weight, bias, **kwargs):
        # weight: matrix applied to the concatenated [inputs; h] vector and then
        # split into the four gate pre-activations in __call__ — presumably of
        # shape (input_dim + num_units, 4 * num_units); confirm against callers.
        # bias: the matching gate bias vector.
        # Remaining kwargs are forwarded unchanged to BasicLSTMCell.
        self.weight = weight
        self.bias = bias
        super(ExternallyParameterisedLSTM, self).__init__(**kwargs)
    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
        with _checked_scope(self, scope or "basic_lstm_cell", reuse=self._reuse):
            # Parameters of gates are concatenated into one multiply for efficiency.
            if self._state_is_tuple:
                c, h = state
            else:
                # Non-tuple state packs (c, h) side by side along axis 1.
                c, h = tf.split(value=state, num_or_size_splits=2, axis=1)
            all_inputs = tf.concat([inputs, h], 1)
            # Single matmul against the externally supplied weight/bias yields all
            # four gate pre-activations at once (instead of the cell creating its
            # own variables as BasicLSTMCell normally does).
            concat = tf.nn.bias_add(tf.matmul(all_inputs, self.weight), self.bias)
            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1)
            # Standard LSTM state update; _forget_bias shifts the forget gate's
            # pre-activation (inherited from BasicLSTMCell's configuration).
            new_c = (c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) *
                     self._activation(j))
            new_h = self._activation(new_c) * tf.sigmoid(o)
            if self._state_is_tuple:
                new_state = LSTMStateTuple(new_c, new_h)
            else:
                new_state = tf.concat([new_c, new_h], 1)
            return new_h, new_state