|
| 1 | +""" This file contains different utility functions that are not connected |
| 2 | +in any way to the networks presented in the tutorials, but rather help in |
| 3 | +processing the outputs into a more understandable way. |
| 4 | +
|
| 5 | +For example ``tile_raster_images`` helps in generating an easy-to-grasp |
| 6 | +image from a set of samples or weights. |
| 7 | +""" |
| 8 | + |
| 9 | + |
| 10 | +import numpy |
| 11 | +from six.moves import xrange |
| 12 | + |
| 13 | + |
def scale_to_unit_interval(ndar, eps=1e-8):
    """ Scales all values in the ndarray ndar to be between 0 and 1

    The input is never modified; a rescaled copy is returned.

    :type ndar: ndarray (or anything ``numpy.asanyarray`` accepts)
    :param ndar: the values to rescale

    :type eps: float
    :param eps: small constant added to the denominator so that a constant
                array (whose range is zero) maps to all zeros instead of
                triggering a division by zero

    :returns: ``(ndar - ndar.min()) / (ndar.max() - ndar.min() + eps)``
    :rtype: ndarray of the same shape as ``ndar``; floating-point and complex
            inputs keep their dtype, every other dtype is promoted to
            float64
    """
    ndar = numpy.asanyarray(ndar)
    if numpy.issubdtype(ndar.dtype, numpy.inexact):
        ndar = ndar.copy()
    else:
        # Integer/bool data cannot hold values in [0, 1], and the in-place
        # float multiplication below would be rejected (or truncated to 0)
        # for such dtypes, so work on a float64 copy instead.
        ndar = ndar.astype('float64')

    if ndar.size == 0:
        # min()/max() are undefined for empty arrays; nothing to rescale.
        return ndar

    ndar -= ndar.min()
    ndar *= 1.0 / (ndar.max() + eps)
    return ndar
| 20 | + |
| 21 | + |
def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0),
                       scale_rows_to_unit_interval=True,
                       output_pixel_vals=True):
    """
    Transform an array with one flattened image per row into an array in
    which the images are reshaped and laid out like tiles on a floor.

    This function is useful for visualizing datasets whose rows are images,
    and also columns of matrices for transforming those rows
    (such as the first layer of a neural net).

    :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can
             be 2-D ndarrays or None;
    :param X: a 2-D array in which every row is a flattened image. When a
              tuple is given, each element is tiled separately and stored
              as one channel of the output; a ``None`` channel is filled
              with 0, except the fourth (alpha) channel which is filled
              with full opacity (255 or 1.0).

    :type img_shape: tuple; (height, width)
    :param img_shape: the original shape of each image

    :type tile_shape: tuple; (rows, cols)
    :param tile_shape: the number of images to tile (rows, cols)

    :type tile_spacing: tuple; (rows, cols)
    :param tile_spacing: number of blank (zero) pixels left between
                         neighbouring tiles, vertically and horizontally

    :param scale_rows_to_unit_interval: if the values need to be scaled before
                                        being plotted to [0,1] or not

    :param output_pixel_vals: if output should be pixel values (i.e. uint8
                              values, obtained by multiplying by 255) or
                              values in the dtype of the input

    :returns: array suitable for viewing as an image.
              (See:`Image.fromarray`.)
    :rtype: a 2-d array (3-d with a trailing axis of length 4 when X is a
            tuple); dtype is uint8 if ``output_pixel_vals`` is true,
            otherwise the dtype of X (for a tuple: the common dtype of the
            channels that were supplied, float64 if all are None).
    """

    assert len(img_shape) == 2
    assert len(tile_shape) == 2
    assert len(tile_spacing) == 2

    # Each axis holds `tshp` images of size `ishp` separated by `tsp` blank
    # pixels; there is no spacing after the last tile, hence the `- tsp`.
    out_shape = [
        (ishp + tsp) * tshp - tsp
        for ishp, tshp, tsp in zip(img_shape, tile_shape, tile_spacing)
    ]

    if isinstance(X, tuple):
        assert len(X) == 4

        # colors default to 0, alpha defaults to 1 (opaque)
        if output_pixel_vals:
            out_dtype = 'uint8'
            channel_defaults = [0, 0, 0, 255]
        else:
            # X is a tuple here and has no dtype of its own, so the output
            # dtype is taken from the channels that were actually supplied.
            given = [chan.dtype for chan in X if chan is not None]
            if given:
                out_dtype = numpy.result_type(*given)
            else:
                out_dtype = numpy.dtype('float64')
            channel_defaults = [0., 0., 0., 1.]

        # Create an output numpy ndarray to store the image
        out_array = numpy.zeros((out_shape[0], out_shape[1], 4),
                                dtype=out_dtype)

        # The builtin range behaves the same on Python 2 and 3 for these
        # small loops.
        for i in range(4):
            if X[i] is None:
                # missing channel: fill it with its default value
                out_array[:, :, i] = channel_defaults[i]
            else:
                # use a recursive call to compute the channel and store it
                # in the output
                out_array[:, :, i] = tile_raster_images(
                    X[i], img_shape, tile_shape, tile_spacing,
                    scale_rows_to_unit_interval, output_pixel_vals)
        return out_array

    # if we are dealing with only one channel
    H, W = img_shape
    Hs, Ws = tile_spacing
    n_tile_rows, n_tile_cols = tile_shape

    # generate a matrix to store the output; pixel output is uint8 and the
    # tile values are multiplied by 255 before being stored
    if output_pixel_vals:
        out_dtype = 'uint8'
        c = 255
    else:
        out_dtype = X.dtype
        c = 1
    out_array = numpy.zeros(out_shape, dtype=out_dtype)

    for tile_row in range(n_tile_rows):
        for tile_col in range(n_tile_cols):
            idx = tile_row * n_tile_cols + tile_col
            if idx >= X.shape[0]:
                # Tiles are visited in row-major order, so once the images
                # run out every remaining tile stays blank.
                return out_array

            this_img = X[idx].reshape(img_shape)
            if scale_rows_to_unit_interval:
                # scale values to be between 0 and 1 by calling the
                # `scale_to_unit_interval` function
                this_img = scale_to_unit_interval(this_img)

            # add the slice to the corresponding position in the
            # output array
            top = tile_row * (H + Hs)
            left = tile_col * (W + Ws)
            out_array[top:top + H, left:left + W] = this_img * c
    return out_array
0 commit comments