 import tensorflow as tf
 
 
-def conv_layer(input, num_input_channels,
-               filter_height, filter_width,
-               num_filters, seed=None, use_pooling=True):
+def conv_layer(
+    input,
+    num_input_channels,
+    filter_height,
+    filter_width,
+    num_filters,
+    seed=None,
+    use_pooling=True,
+):
     shape = [filter_height, filter_width, num_input_channels, num_filters]
     weights = tf.Variable(tf.truncated_normal(shape, stddev=0.05, seed=seed))
     biases = tf.Variable(tf.constant(0.05, shape=[num_filters]))
-    layer = tf.nn.conv2d(input=input, filter=weights,
-                         strides=[1, 1, 1, 1], padding="VALID")
+    layer = tf.nn.conv2d(
+        input=input, filter=weights, strides=[1, 1, 1, 1], padding="VALID"
+    )
     layer = layer + biases
     if use_pooling:
-        layer = tf.nn.max_pool(value=layer, ksize=[1, input.shape[1] - filter_height + 1, 1, 1],
-                               strides=[1, 1, 1, 1], padding="VALID")
+        layer = tf.nn.max_pool(
+            value=layer,
+            ksize=[1, input.shape[1] - filter_height + 1, 1, 1],
+            strides=[1, 1, 1, 1],
+            padding="VALID",
+        )
     layer = tf.nn.relu(layer)
     return layer, weights
 
@@ -40,28 +51,39 @@ def flatten_layer(layer):
 
 
 def fc_layer(input, num_input, num_output, seed=None):
-    weights = tf.Variable(tf.truncated_normal([num_input, num_output], stddev=0.05, seed=seed))
+    weights = tf.Variable(
+        tf.truncated_normal([num_input, num_output], stddev=0.05, seed=seed)
+    )
     biases = tf.Variable(tf.constant(0.05, shape=[num_output]))
     layer = tf.matmul(input, weights) + biases
     layer = tf.nn.tanh(layer)
     return layer
 
 
-class CNN_module():
-
-    def __init__(self, output_dimension, dropout_rate,
-                 emb_dim, max_len, nb_filters, seed,
-                 init_W, learning_rate=0.001):
+class CNN_module:
+    def __init__(
+        self,
+        output_dimension,
+        dropout_rate,
+        emb_dim,
+        max_len,
+        filter_sizes,
+        num_filters,
+        hidden_dim,
+        seed,
+        init_W,
+        learning_rate=0.001,
+    ):
         self.drop_rate = dropout_rate
         self.max_len = max_len
         self.seed = seed
         self.learning_rate = learning_rate
         self.init_W = tf.constant(init_W)
         self.output_dimension = output_dimension
         self.emb_dim = emb_dim
-        self.nb_filters = nb_filters
-        self.filter_lengths = [3, 4, 5]
-        self.vanila_dimension = 200
+        self.filter_lengths = filter_sizes
+        self.nb_filters = num_filters
+        self.vanila_dimension = hidden_dim
 
         self._build_graph()
 
@@ -76,28 +98,46 @@ def _build_graph(self):
         self.reshape = tf.reshape(self.seq_emb, [-1, self.max_len, self.emb_dim, 1])
         self.convs = []
 
-        # Convolutional layer
+        # Convolutional layers
         for i in self.filter_lengths:
-            convolutional_layer, weights = conv_layer(input=self.reshape, num_input_channels=1,
-                                                      filter_height=i, filter_width=self.emb_dim,
-                                                      num_filters=self.nb_filters, use_pooling=True)
+            convolutional_layer, weights = conv_layer(
+                input=self.reshape,
+                num_input_channels=1,
+                filter_height=i,
+                filter_width=self.emb_dim,
+                num_filters=self.nb_filters,
+                use_pooling=True,
+            )
 
             flat_layer, _ = flatten_layer(convolutional_layer)
             self.convs.append(flat_layer)
 
         self.model_output = tf.concat(self.convs, axis=-1)
         # Fully-connected layers
-        self.model_output = fc_layer(input=self.model_output, num_input=self.model_input.get_shape()[1].value,
-                                     num_output=self.vanila_dimension)
+        self.model_output = fc_layer(
+            input=self.model_output,
+            num_input=self.model_output.get_shape()[-1].value,
+            num_output=self.vanila_dimension,
+        )
         # Dropout layer
         self.model_output = tf.nn.dropout(self.model_output, self.drop_rate)
         # Output layer
-        self.model_output = fc_layer(input=self.model_output, num_input=self.vanila_dimension,
-                                     num_output=self.output_dimension)
+        self.model_output = fc_layer(
+            input=self.model_output,
+            num_input=self.vanila_dimension,
+            num_output=self.output_dimension,
+        )
         # Weighted MSE loss function
-        self.mean_square_loss = tf.losses.mean_squared_error(labels=self.v, predictions=self.model_output,
-                                                             reduction=tf.losses.Reduction.NONE)
+        self.mean_square_loss = tf.losses.mean_squared_error(
+            labels=self.v,
+            predictions=self.model_output,
+            reduction=tf.losses.Reduction.NONE,
+        )
         self.weighted_loss = tf.reduce_sum(
-            tf.reduce_sum(self.mean_square_loss, axis=1, keepdims=True) * self.sample_weight)
+            tf.reduce_sum(self.mean_square_loss, axis=1, keepdims=True)
+            * self.sample_weight
+        )
         # RMSProp optimizer
-        self.optimizer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate).minimize(self.weighted_loss)
+        self.optimizer = tf.train.RMSPropOptimizer(
+            learning_rate=self.learning_rate
+        ).minimize(self.weighted_loss)
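For reference, a minimal sketch of how CNN_module might be constructed after this change. The embedding matrix, vocabulary size, and hyperparameter values below are hypothetical placeholders chosen to match the new signature, not values taken from the commit:

import numpy as np

# Hypothetical pretrained word embeddings (vocab_size x emb_dim); any
# float32 matrix of this shape would do for a smoke test.
init_W = np.random.uniform(-1.0, 1.0, (8000, 200)).astype(np.float32)

cnn = CNN_module(
    output_dimension=50,     # dimension of the target vectors self.v
    dropout_rate=0.2,
    emb_dim=200,             # must match init_W's second dimension
    max_len=300,             # padded document length
    filter_sizes=[3, 4, 5],  # previously hard-coded as self.filter_lengths
    num_filters=100,         # previously the nb_filters argument
    hidden_dim=200,          # previously the hard-coded vanila_dimension
    seed=42,
    init_W=init_W,
)

Construction runs _build_graph(), so this assumes the placeholders referenced there but not shown in this diff (self.seq_emb, self.v, self.sample_weight) are defined in the omitted hunks.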