@@ -158,32 +158,13 @@ class functions:
     def elu(Node x, float a):
         return wrapNode(func_elu(x.wrapped, a))
 
-    @staticmethod
-    def selu(Node x, float a, float s):
-        return wrapNode(func_selu(x.wrapped, a, s))
-
     @staticmethod
     def sum(x, dim=None):
-        cdef vector[CppNode] xs
-        cdef Node node
         if isinstance(x, list):
-            for node in x:
-                xs.push_back(node.wrapped)
-            return wrapNode(func_sum(xs))
+            return functions.sum_list(x)
         else:
             return wrapNode(func_sum((<Node>x).wrapped, <unsigned>dim))
 
-    @staticmethod
-    def mean(x, dim=None):
-        cdef vector[CppNode] xs
-        cdef Node node
-        if isinstance(x, list):
-            for node in x:
-                xs.push_back(node.wrapped)
-            return wrapNode(func_mean(xs))
-        else:
-            return wrapNode(func_mean((<Node>x).wrapped, <unsigned>dim))
-
     @staticmethod
     def broadcast(Node x, unsigned dim, unsigned size):
         return wrapNode(func_broadcast(x.wrapped, dim, size))
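Note on this hunk: the public `functions.sum` keeps its dual signature, but the list branch now routes through the pure-Python `sum_list` helper added in the next hunk instead of the C++ `func_sum(vector[CppNode])` overload. A minimal usage sketch, assuming `a` and `b` are hypothetical Node objects of identical shape:

    total = functions.sum([a, b])   # list input -> functions.sum_list([a, b])
    first = functions.sum(a, 0)     # (node, dim) input -> C++ func_sum along dimension 0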
@@ -243,43 +224,76 @@ class functions:
                         get_cpp_device(device), get_cpp_graph(graph)))
 
     @staticmethod
-    def zeros(shape, Device device = None, Graph graph = None):
+    def identity(unsigned size, Device device = None, Graph graph = None):
         if device is None:
             device = Device.get_default()
         if graph is None:
             graph = Graph.get_default()
-        return wrapNode(func_zeros_node(normShape(shape).wrapped,
-                        get_cpp_device(device), get_cpp_graph(graph)))
+        return wrapNode(func_identity_node(size, get_cpp_device(device), get_cpp_graph(graph)))
+
+    # contrib functions
 
     @staticmethod
-    def ones(shape, Device device = None, Graph graph = None):
-        if device is None:
-            device = Device.get_default()
-        if graph is None:
-            graph = Graph.get_default()
-        return wrapNode(func_ones_node(normShape(shape).wrapped,
-                        get_cpp_device(device), get_cpp_graph(graph)))
+    def selu(Node x, float a = 1.6732632423543772848170429916717, float s = 1.0507009873554804934193349852946):
+        return s * functions.elu(x, a)
 
     @staticmethod
-    def identity(unsigned size, Device device = None, Graph graph = None):
-        if device is None:
-            device = Device.get_default()
-        if graph is None:
-            graph = Graph.get_default()
-        return wrapNode(func_identity_node(size, get_cpp_device(device), get_cpp_graph(graph)))
+    def sum_list(list xs):
+        if not xs:
+            raise TypeError("No nodes to sum.")
+        ret = xs[0]
+        for x in xs[1:]:
+            ret = ret + x
+        return ret
+
+    @staticmethod
+    def mean(x, dim=None):
+        if isinstance(x, list):
+            return functions.sum_list(x) / len(x)
+        else:
+            return functions.sum(x, dim) / x.shape()[dim]
+
+    @staticmethod
+    def zeros(shape, Device dev = None, Graph g = None):
+        return functions.constant(shape, 0.0, dev, g)
+
+    @staticmethod
+    def ones(shape, Device dev = None, Graph g = None):
+        return functions.constant(shape, 1.0, dev, g)
+
+    @staticmethod
+    def dropout(Node x, float rate, bool enabled):
+        if not enabled:
+            return x
+        if rate == 1.0:
+            return 0.0 * x
+        p = 1.0 - rate
+        return (1.0 / p) * x * functions.random.bernoulli(x.shape(), p, x.device())
+
+    # end contrib functions
 
     class batch:
         @staticmethod
         def sum(Node x):
             return wrapNode(func_batch_sum[CppNode](x.wrapped))
 
+        # contrib functions
+
         @staticmethod
         def mean(Node x):
-            return wrapNode(func_batch_mean[CppNode](x.wrapped))
+            return functions.batch.sum(x) / x.shape().batch()
 
         @staticmethod
         def normalize(Node x):
-            return wrapNode(func_batch_normalize[CppNode](x.wrapped))
+            if not x.shape().has_batch():
+                return x
+            b = x.shape().batch()
+            scale = b / (b - 1)
+            m = functions.batch.mean(x)
+            v = scale * (functions.batch.mean(x * x) - m * m)
+            return (x - m) / functions.sqrt(v + 1e-8)
+
+        # end contrib functions
 
     class random:
         @staticmethod
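The contrib `dropout` above is the standard inverted-dropout formulation: each unit survives with probability p = 1 - rate, and survivors are scaled by 1/p so that the expected output equals the input and no rescaling is needed at evaluation time. A self-contained pure-Python sketch of the same idea (plain floats, not the binding's Node type):

    import random

    def dropout_sketch(values, rate):
        # Inverted dropout: E[(1/p) * v * Bernoulli(p)] = v for each element.
        p = 1.0 - rate
        return [v / p if random.random() < p else 0.0 for v in values]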
@@ -327,10 +341,6 @@ class functions:
             return wrapNode(func_random_gumbel_node(normShape(shape).wrapped, mu, beta,
                             get_cpp_device(device), get_cpp_graph(graph)))
 
-    @staticmethod
-    def dropout(Node x, float rate, bool enabled):
-        return wrapNode(func_dropout(x.wrapped, rate, enabled))
-
 
 class tensor_functions:
 
@@ -472,32 +482,13 @@ class tensor_functions:
     def elu(Tensor x, float a):
         return Tensor.get_wrapper_with_new(new CppTensor(func_elu(x.wrapped[0], a)))
 
-    @staticmethod
-    def selu(Tensor x, float a, float s):
-        return Tensor.get_wrapper_with_new(new CppTensor(func_selu(x.wrapped[0], a, s)))
-
     @staticmethod
     def sum(x, dim=None):
-        cdef vector[CppTensor] xs
-        cdef Tensor t
         if isinstance(x, list):
-            for t in x:
-                xs.push_back(t.wrapped[0])
-            return Tensor.get_wrapper_with_new(new CppTensor(func_sum(xs)))
+            return tensor_functions.sum_list(x)
         else:
             return Tensor.get_wrapper_with_new(new CppTensor(func_sum((<Tensor>x).wrapped[0], <unsigned>dim)))
 
-    @staticmethod
-    def mean(x, dim=None):
-        cdef vector[CppTensor] xs
-        cdef Tensor t
-        if isinstance(x, list):
-            for t in x:
-                xs.push_back(t.wrapped[0])
-            return Tensor.get_wrapper_with_new(new CppTensor(func_mean(xs)))
-        else:
-            return Tensor.get_wrapper_with_new(new CppTensor(func_mean((<Tensor>x).wrapped[0], <unsigned>dim)))
-
     @staticmethod
     def broadcast(Tensor x, unsigned dim, unsigned size):
         return Tensor.get_wrapper_with_new(new CppTensor(func_broadcast(x.wrapped[0], dim, size)))
@@ -554,38 +545,76 @@ class tensor_functions:
         return Tensor.get_wrapper_with_new(new CppTensor(func_constant_tensor(normShape(shape).wrapped, k,
                                            get_cpp_device(device))))
 
+
     @staticmethod
-    def zeros(shape, Device device = None):
+    def identity(unsigned size, Device device = None):
         if device is None:
             device = Device.get_default()
-        return Tensor.get_wrapper_with_new(new CppTensor(func_zeros_tensor(normShape(shape).wrapped,
-                                           get_cpp_device(device))))
+        return Tensor.get_wrapper_with_new(new CppTensor(func_identity_tensor(size, get_cpp_device(device))))
+
+    # contrib functions
 
     @staticmethod
-    def ones(shape, Device device = None):
-        if device is None:
-            device = Device.get_default()
-        return Tensor.get_wrapper_with_new(new CppTensor(func_ones_tensor(normShape(shape).wrapped,
-                                           get_cpp_device(device))))
+    def selu(Tensor x, float a = 1.6732632423543772848170429916717, float s = 1.0507009873554804934193349852946):
+        return s * tensor_functions.elu(x, a)
 
     @staticmethod
-    def identity(unsigned size, Device device = None):
-        if device is None:
-            device = Device.get_default()
-        return Tensor.get_wrapper_with_new(new CppTensor(func_identity_tensor(size, get_cpp_device(device))))
+    def sum_list(list xs):
+        if not xs:
+            raise TypeError("No tensors to sum.")
+        ret = xs[0]
+        for x in xs[1:]:
+            ret = ret + x
+        return ret
+
+    @staticmethod
+    def mean(x, dim=None):
+        if isinstance(x, list):
+            return tensor_functions.sum_list(x) / len(x)
+        else:
+            return tensor_functions.sum(x, dim) / x.shape()[dim]
+
+    @staticmethod
+    def zeros(shape, Device dev = None):
+        return tensor_functions.constant(shape, 0.0, dev)
+
+    @staticmethod
+    def ones(shape, Device dev = None):
+        return tensor_functions.constant(shape, 1.0, dev)
+
+    @staticmethod
+    def dropout(Tensor x, float rate, bool enabled):
+        if not enabled:
+            return x
+        if rate == 1.0:
+            return 0.0 * x
+        p = 1.0 - rate
+        return (1.0 / p) * x * tensor_functions.random.bernoulli(x.shape(), p, x.device())
+
+    # end contrib functions
 
     class batch:
         @staticmethod
         def sum(Tensor x):
             return Tensor.get_wrapper_with_new(new CppTensor(func_batch_sum[CppTensor](x.wrapped[0])))
 
+        # contrib functions
+
         @staticmethod
         def mean(Tensor x):
-            return Tensor.get_wrapper_with_new(new CppTensor(func_batch_mean[CppTensor](x.wrapped[0])))
+            return tensor_functions.batch.sum(x) / x.shape().batch()
 
         @staticmethod
         def normalize(Tensor x):
-            return Tensor.get_wrapper_with_new(new CppTensor(func_batch_normalize[CppTensor](x.wrapped[0])))
+            if not x.shape().has_batch():
+                return x
+            b = x.shape().batch()
+            scale = b / (b - 1)
+            m = tensor_functions.batch.mean(x)
+            v = scale * (tensor_functions.batch.mean(x * x) - m * m)
+            return (x - m) / tensor_functions.sqrt(v + 1e-8)
+
+        # end contrib functions
 
     class random:
         @staticmethod
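For reference, both `batch.normalize` bodies standardize each element over the mini-batch using the Bessel-corrected variance, matching `scale = b / (b - 1)` in the code; with batch size b, in LaTeX notation:

    m = \frac{1}{b} \sum_{i=1}^{b} x_i, \qquad
    v = \frac{b}{b-1} \left( \frac{1}{b} \sum_{i=1}^{b} x_i^2 - m^2 \right), \qquad
    \mathrm{normalize}(x_i) = \frac{x_i - m}{\sqrt{v + 10^{-8}}}

The 1e-8 term only guards against division by zero when the batch variance vanishes.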
@@ -622,7 +651,3 @@ class tensor_functions:
             device = Device.get_default()
         return Tensor.get_wrapper_with_new(new CppTensor(func_random_gumbel_tensor(normShape(shape).wrapped, mu, beta,
                                            get_cpp_device(device))))
-
-    @staticmethod
-    def dropout(Tensor x, float rate, bool enabled):
-        return Tensor.get_wrapper_with_new(new CppTensor(func_dropout(x.wrapped[0], rate, enabled)))
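Since `tensor_functions` mirrors `functions` but evaluates eagerly on `Tensor` objects, the relocated contrib helpers compose identically in both APIs. A hypothetical smoke test (shapes and default-device setup are assumptions, not part of this diff):

    z = tensor_functions.zeros([2, 3])    # delegates to constant(shape, 0.0, None)
    o = tensor_functions.ones([2, 3])     # delegates to constant(shape, 1.0, None)
    m = tensor_functions.mean([z, o])     # sum_list([z, o]) / 2, all entries 0.5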