diff --git a/capsLayer.py b/capsLayer.py index 136dca7..cf88a5a 100644 --- a/capsLayer.py +++ b/capsLayer.py @@ -69,12 +69,12 @@ def __call__(self, input, kernel_size=None, stride=None): if self.with_routing: # the DigitCaps layer, a fully connected layer # Reshape the input into [batch_size, 1152, 1, 8, 1] - self.input = tf.reshape(input, shape=(cfg.batch_size, -1, 1, input.shape[-2].value, 1)) + self.input = tf.reshape(input, shape=(cfg.batch_size, -1, 1, input.get_shape()[-2].value, 1)) with tf.variable_scope('routing'): # b_IJ: [batch_size, num_caps_l, num_caps_l_plus_1, 1, 1], # about the reason of using 'batch_size', see issue #21 - b_IJ = tf.constant(np.zeros([cfg.batch_size, input.shape[1].value, self.num_outputs, 1, 1], dtype=np.float32)) + b_IJ = tf.constant(np.zeros([cfg.batch_size, input.get_shape()[1].value, self.num_outputs, 1, 1], dtype=np.float32)) capsules = routing(self.input, b_IJ, num_outputs=self.num_outputs, num_dims=self.vec_len) capsules = tf.squeeze(capsules, axis=1)