1+ import numpy as np
2+ import torch
3+
4+ class BaseDifferentiableImageGenerator (object ):
5+
6+ """Base class for all image generators that are written in pytorch and differentiable into theta.
7+ All classes should output an image of ImageNet dimensions (224x224).
8+ """
9+
10+ def __init__ (self ):
11+
12+ pass
13+
14+ def generate_image (self , theta , context ):
15+ """To be redefined in any subclass.
16+
17+ Deterministically generates an image as a function of theta and context.
18+
19+ Context and theta should both smoothly relate to the image. Small changes in
20+ theta and changes in context should cause small changes in the output image.
21+ For two values of theta but the same context, the two output images should be
22+ the same in all manners except for theta.
23+
24+ Methods need to be written in pytorch, such that they take a Variable and produce an
25+ image that is differentiable with respect to theta
26+
27+ """
28+ raise NotImplementedError ()
29+
30+
31+ class BaseNonDifferentiableImageGenerator (object ):
32+ """Base class for all image generators that aren't differentiable.
33+
34+ All classes should output an image of ImageNet dimensions (224x224).
35+
36+ When possible, all methods should still use torch methods and not numpy or scipy methods, for speed."""
37+
38+ def __init__ (self ):
39+ pass
40+
41+ def generate_image (self , theta , context ):
42+ """To be redefined in any subclass.
43+
44+ Deterministically generates an image as a function of theta and context.
45+
46+ Context and theta should both smoothly relate to the image. Small changes in
47+ theta and changes in context should cause small changes in the output image.
48+ For two values of theta but the same context, the two output images should be
49+ the same in all manners except for theta.
50+
51+ """
52+ raise NotImplementedError ()
53+
54+ class OneCurvedLineGenerator (BaseDifferentiableImageGenerator ):
55+ """This is designed to create illusions like the Herring illusion.
56+ """
57+
58+ def __init__ (self , n_lines = 5 ):
59+ super (OneCurvedLineGenerator , self ).__init__ ()
60+ self .n_lines = n_lines
61+
62+
63+ def generate_image (self , theta , context ):
64+ """
65+ Theta is the curvature of the central horizontal line.
66+ Context are parameters that describe the overlaid lines.
67+ e.g. a list of midpoints and orientations, for a total of n_lines x 3 parameters. """
68+
69+ raise NotImplementedError ()
70+
71+
72+
73+ class CentralPixelGenerator (BaseDifferentiableImageGenerator ):
74+ """This is designed to generate images with a central block of pixels whose lightness is an illusions.
75+ """
76+
77+ def __init__ (self , n_pixels_blocks_per_side = 5 ):
78+ """
79+
80+
81+ :param n_pixels_blocks_per_side: an odd integer
82+ """
83+ super (CentralPixelGenerator , self ).__init__ ()
84+ assert n_pixels_blocks_per_side % 2 == 1 , "n_pixels_blocks_per_side must be odd"
85+
86+ self .n_pixels_blocks_per_side = n_pixels_blocks_per_side
87+
88+
89+ def generate_image (self , theta , context ):
90+ """
91+ :param theta: the lightness of the central pixel block.
92+ :param context: the lightness of the surrounding pixels.
93+ :return: A 224x224x3 image
94+ """
95+ assert len (context ) == self .n_pixels_blocks_per_side ** 2 - 1
96+
97+ raise NotImplementedError ()
0 commit comments