Skip to content

Commit 282e0e5

Browse files
committed
fix shape issues, improve performance, make it pip installable, etc
1 parent a4debae commit 282e0e5

File tree

13 files changed

+67
-25
lines changed

13 files changed

+67
-25
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -96,4 +96,8 @@ ENV/
9696
/site
9797

9898
# mypy
99-
.mypy_cache/
99+
.mypy_cache/
100+
101+
# vscode and its extensions
102+
.vscode/*
103+
.history/*

dnc/__init__.py

Whitespace-only changes.

access.py renamed to dnc/access.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import sonnet as snt
2323
import tensorflow as tf
2424

25-
import addressing
26-
import util
25+
from dnc import addressing
26+
from dnc import util
2727

2828
AccessState = collections.namedtuple('AccessState', (
2929
'memory', 'read_weights', 'write_weights', 'linkage', 'usage'))
@@ -53,7 +53,7 @@ def _erase_and_write(memory, address, reset_weights, values):
5353
expand_address = tf.expand_dims(address, 3)
5454
reset_weights = tf.expand_dims(reset_weights, 2)
5555
weighted_resets = expand_address * reset_weights
56-
reset_gate = tf.reduce_prod(1 - weighted_resets, [1])
56+
reset_gate = util.reduce_prod(1 - weighted_resets, 1)
5757
memory *= reset_gate
5858

5959
with tf.name_scope('additive_write', values=[memory, address, values]):

access_test.py renamed to dnc/access_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import tensorflow as tf
2323
from tensorflow.python.ops import rnn
2424

25-
import access
26-
import util
25+
from dnc import access
26+
from dnc import util
2727

2828
BATCH_SIZE = 2
2929
MEMORY_SIZE = 20

addressing.py renamed to dnc/addressing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import sonnet as snt
2323
import tensorflow as tf
2424

25-
import util
25+
from dnc import util
2626

2727
# Ensure values are greater than epsilon to avoid numerical instability.
2828
_EPSILON = 1e-6
@@ -32,7 +32,7 @@
3232

3333

3434
def _vector_norms(m):
  """Return the L2 norm along axis 2 of `m`, keeping the reduced dimension.

  `_EPSILON` is added under the square root to avoid a zero gradient /
  numerical instability when a vector is all zeros.
  """
  sum_of_squares = tf.reduce_sum(tf.square(m), axis=2, keepdims=True)
  return tf.sqrt(sum_of_squares + _EPSILON)
3737

3838

@@ -202,7 +202,7 @@ def _link(self, prev_link, prev_precedence_weights, write_weights):
202202
containing the new link graphs for each write head.
203203
"""
204204
with tf.name_scope('link'):
205-
batch_size = prev_link.get_shape()[0].value
205+
batch_size = tf.shape(prev_link)[0]
206206
write_weights_i = tf.expand_dims(write_weights, 3)
207207
write_weights_j = tf.expand_dims(write_weights, 2)
208208
prev_precedence_weights_j = tf.expand_dims(prev_precedence_weights, 2)
@@ -236,7 +236,7 @@ def _precedence_weights(self, prev_precedence_weights, write_weights):
236236
new precedence weights.
237237
"""
238238
with tf.name_scope('precedence_weights'):
239-
write_sum = tf.reduce_sum(write_weights, 2, keep_dims=True)
239+
write_sum = tf.reduce_sum(write_weights, 2, keepdims=True)
240240
return (1 - write_sum) * prev_precedence_weights + write_weights
241241

242242
@property
@@ -351,7 +351,7 @@ def _usage_after_write(self, prev_usage, write_weights):
351351
"""
352352
with tf.name_scope('usage_after_write'):
353353
# Calculate the aggregated effect of all write heads
354-
write_weights = 1 - tf.reduce_prod(1 - write_weights, [1])
354+
write_weights = 1 - util.reduce_prod(1 - write_weights, 1)
355355
return prev_usage + (1 - prev_usage) * write_weights
356356

357357
def _usage_after_read(self, prev_usage, free_gate, read_weights):
@@ -370,7 +370,7 @@ def _usage_after_read(self, prev_usage, free_gate, read_weights):
370370
with tf.name_scope('usage_after_read'):
371371
free_gate = tf.expand_dims(free_gate, -1)
372372
free_read_weights = free_gate * read_weights
373-
phi = tf.reduce_prod(1 - free_read_weights, [1], name='phi')
373+
phi = util.reduce_prod(1 - free_read_weights, 1, name='phi')
374374
return prev_usage * phi
375375

376376
def _allocation(self, usage):
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import sonnet as snt
2323
import tensorflow as tf
2424

25-
import addressing
26-
import util
25+
from dnc import addressing
26+
from dnc import util
2727

2828

2929
class WeightedSoftmaxTest(tf.test.TestCase):

dnc.py renamed to dnc/dnc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import sonnet as snt
2828
import tensorflow as tf
2929

30-
import access
30+
from dnc import access
3131

3232
DNCState = collections.namedtuple('DNCState', ('access_output', 'access_state',
3333
'controller_state'))
@@ -110,7 +110,7 @@ def _build(self, inputs, prev_state):
110110
controller_input, prev_controller_state)
111111

112112
controller_output = self._clip_if_enabled(controller_output)
113-
controller_state = snt.nest.map(self._clip_if_enabled, controller_state)
113+
controller_state = tf.contrib.framework.nest.map_structure(self._clip_if_enabled, controller_state)
114114

115115
access_output, access_state = self._access(controller_output,
116116
prev_access_state)
File renamed without changes.

util.py renamed to dnc/util.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,48 @@
2525
def batch_invert_permutation(permutations):
  """Returns batched `tf.invert_permutation` for every row in `permutations`.

  Args:
    permutations: `[batch_size, dim]` tensor; each row is assumed to be a
      permutation of `0..dim-1`. NOTE(review): `dim` must be statically
      known (`get_shape()[-1]`) — confirm callers always provide it.

  Returns:
    `[batch_size, dim]` int32 tensor holding the row-wise inverse
    permutations.
  """
  with tf.name_scope('batch_invert_permutation', values=[permutations]):
    perm = tf.cast(permutations, tf.int32)
    dim = int(perm.get_shape()[-1])
    # Offset row i by i * dim so the batch flattens into one disjoint
    # permutation of 0..batch_size*dim-1; invert it with a single
    # tf.invert_permutation call, then strip the offsets again. Unlike the
    # previous float32 round-trip, int32 arithmetic stays exact for offsets
    # beyond 2**24.
    size = tf.shape(perm)[0]
    offsets = tf.expand_dims(tf.range(size) * dim, 1)
    flat = tf.reshape(perm + offsets, [-1])
    inverted = tf.invert_permutation(flat)
    return tf.reshape(inverted, [-1, dim]) - offsets
3140

3241

3342
def batch_gather(values, indices):
3443
"""Returns batched `tf.gather` for every row in the input."""
3544
with tf.name_scope('batch_gather', values=[values, indices]):
36-
unpacked = zip(tf.unstack(values), tf.unstack(indices))
37-
result = [tf.gather(value, index) for value, index in unpacked]
38-
return tf.stack(result)
45+
idx = tf.expand_dims(indices, -1)
46+
size = tf.shape(indices)[0]
47+
rg = tf.range(size, dtype=tf.int32)
48+
rg = tf.expand_dims(rg, -1)
49+
rg = tf.tile(rg, [1, int(indices.get_shape()[-1])])
50+
rg = tf.expand_dims(rg, -1)
51+
gidx = tf.concat([rg, idx], -1)
52+
return tf.gather_nd(values, gidx)
3953

4054

4155
def one_hot(length, index):
  """Return a 1-D numpy array of `length` zeros with a single 1 at `index`."""
  vec = np.zeros(length)
  vec[index] = 1
  return vec
60+
61+
def reduce_prod(x, axis, name=None):
  """Efficient reduce product over axis.

  Uses tf.cumprod and tf.gather_nd as a workaround to the poor performance
  of calculating tf.reduce_prod's gradient on CPU: the reverse cumulative
  product's leading element along the reduced axis equals the full product.

  Args:
    x: Tensor to reduce.
    axis: Axis to reduce over. NOTE(review): the gather below selects
      cp[i, 0, ...], which is only the reduced result when `axis == 1` with
      the batch on axis 0 — confirm against callers before generalizing.
    name: Optional op-scope name.

  Returns:
    Tensor equal to the product of `x`'s elements along `axis`.
  """
  with tf.name_scope(name, 'util_reduce_prod', values=[x]):
    cp = tf.cumprod(x, axis, reverse=True)
    size = tf.shape(cp)[0]
    # Build the (i, 0) gather coordinates directly in int32; the previous
    # version constructed them in float32 and cast back, which loses
    # exactness for batch sizes beyond 2**24 and adds needless casts.
    row_ids = tf.range(size)
    zeros = tf.zeros([size], tf.int32)
    indices = tf.stack([row_ids, zeros], 1)
    return tf.gather_nd(cp, indices)

0 commit comments

Comments
 (0)