Commit 692e87c

Implement contrib functions in Python

Parent: aebd0fb
18 files changed: +127 additions, −115 deletions

primitiv-core

Submodule primitiv-core updated 331 files

primitiv/_device.pxd

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-cdef extern from "primitiv/device.h":
+cdef extern from "primitiv/core/device.h":
     cdef cppclass CppDevice "primitiv::Device":
         void dump_description() except +
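
Every .pxd hunk in this commit makes the same mechanical change: the updated primitiv-core submodule moved its public headers under a core/ subdirectory, so each cdef extern block is repointed from primitiv/*.h to primitiv/core/*.h while the wrapped symbols stay identical. A minimal before/after sketch of the pattern (the two blocks are juxtaposed for comparison only; a real module declares each class once):

    # Before this commit: headers at the top of the primitiv include tree.
    cdef extern from "primitiv/device.h":
        cdef cppclass CppDevice "primitiv::Device":
            void dump_description() except +

    # After: the same declaration, repointed at the relocated header.
    cdef extern from "primitiv/core/device.h":
        cdef cppclass CppDevice "primitiv::Device":
            void dump_description() except +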

primitiv/_function.pxd

Lines changed: 3 additions & 15 deletions
@@ -9,7 +9,7 @@ from primitiv._shape cimport CppShape
 from primitiv._parameter cimport CppParameter
 
 
-cdef extern from "primitiv/functions.h":
+cdef extern from "primitiv/core/basic_functions.h":
     CppTensor func_input_tensor "primitiv::functions::input_tensor" (const CppShape &shape, const vector[float] &data, CppDevice *dev) except +
     CppNode func_input_node "primitiv::functions::input_node" (const CppShape &shape, const vector[float] &data, CppDevice *dev, CppGraph *g) except +
     CppTensor func_parameter_tensor "primitiv::functions::parameter_tensor" (CppParameter &param) except +
@@ -40,12 +40,7 @@ cdef extern from "primitiv/functions.h":
     Var func_prelu "primitiv::functions::prelu" [Var](const Var &x, float a) except +
     Var func_elu "primitiv::functions::elu" [Var](const Var &x, float a) except +
     Var func_selu "primitiv::functions::selu" [Var](const Var &x, float a, float s) except +
-    CppNode func_sum "primitiv::functions::sum" (const vector[CppNode] &xs) except +
-    CppTensor func_sum "primitiv::functions::sum" (const vector[CppTensor] &xs) except +
     Var func_sum "primitiv::functions::sum" [Var](const Var &x, unsigned dim) except +
-    CppNode func_mean "primitiv::functions::mean" (const vector[CppNode] &xs) except +
-    CppTensor func_mean "primitiv::functions::mean" (const vector[CppTensor] &xs) except +
-    Var func_mean "primitiv::functions::mean" [Var](const Var &x, unsigned dim) except +
     Var func_broadcast "primitiv::functions::broadcast" [Var](const Var &x, unsigned dim, unsigned size) except +
     Var func_logsumexp "primitiv::functions::logsumexp" [Var](const Var &x, unsigned dim) except +
     Var func_log_softmax "primitiv::functions::log_softmax" [Var](const Var &x, unsigned dim) except +
@@ -58,13 +53,8 @@ cdef extern from "primitiv/functions.h":
 
     CppTensor func_constant_tensor "primitiv::functions::constant_tensor" (const CppShape &shape, float k, CppDevice *dev) except +
     CppNode func_constant_node "primitiv::functions::constant_node" (const CppShape &shape, float k, CppDevice *dev, CppGraph *g) except +
-    CppTensor func_zeros_tensor "primitiv::functions::zeros_tensor" (const CppShape &shape, CppDevice *dev) except +
-    CppNode func_zeros_node "primitiv::functions::zeros_node" (const CppShape &shape, CppDevice *dev, CppGraph *g) except +
-    CppTensor func_ones_tensor "primitiv::functions::ones_tensor" (const CppShape &shape, CppDevice *dev) except +
-    CppNode func_ones_node "primitiv::functions::ones_node" (const CppShape &shape, CppDevice *dev, CppGraph *g) except +
     CppTensor func_identity_tensor "primitiv::functions::identity_tensor" (unsigned size, CppDevice *dev) except +
     CppNode func_identity_node "primitiv::functions::identity_node" (unsigned size, CppDevice *dev, CppGraph *g) except +
-    Var func_dropout "primitiv::functions::dropout" [Var](const Var &x, float rate, bool enabled) except +
 
     Var func_positive "primitiv::functions::positive" [Var](const Var &x) except +
     Var func_negative "primitiv::functions::negative" [Var](const Var &x) except +
@@ -82,13 +72,11 @@ cdef extern from "primitiv/functions.h":
     Var func_divide "primitiv::functions::divide" [Var](const Var &a, const Var &b) except +
 
 
-cdef extern from "primitiv/functions.h":
+cdef extern from "primitiv/core/basic_functions.h":
     Var func_batch_sum "primitiv::functions::batch::sum" [Var](const Var &x) except +
-    Var func_batch_mean "primitiv::functions::batch::mean" [Var](const Var &x) except +
-    Var func_batch_normalize "primitiv::functions::batch::normalize" [Var](const Var &x) except +
 
 
-cdef extern from "primitiv/functions.h":
+cdef extern from "primitiv/core/basic_functions.h":
 
     CppNode func_random_bernoulli_node "primitiv::functions::random::bernoulli_node" (const CppShape &shape, float p, CppDevice *dev, CppGraph *g) except +
     CppTensor func_random_bernoulli_tensor "primitiv::functions::random::bernoulli_tensor" (const CppShape &shape, float p, CppDevice *dev) except +
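
Note that the declarations deleted above (the list overloads of sum and mean, zeros/ones, dropout, and batch::mean/batch::normalize) are not lost: the commit re-implements them in pure Python in primitiv/_function.pyx below, composing them from the primitives that remain declared here. As a stand-alone illustration of the inverted-dropout identity the new Python dropout relies on, here is a sketch over plain Python floats (a scalar stand-in for the node version, not part of this diff):

    import random

    def dropout_scalar(x, rate, enabled=True):
        # Inverted dropout on one scalar: zero it with probability `rate`,
        # otherwise rescale by 1/(1 - rate) so the expected value stays x.
        if not enabled:
            return x
        if rate == 1.0:
            return 0.0
        p = 1.0 - rate                              # keep probability
        mask = 1.0 if random.random() < p else 0.0  # Bernoulli(p) mask
        return (1.0 / p) * x * mask                 # E[output] == x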

primitiv/_function.pyx

Lines changed: 107 additions & 82 deletions
@@ -158,32 +158,13 @@ class functions:
     def elu(Node x, float a):
         return wrapNode(func_elu(x.wrapped, a))
 
-    @staticmethod
-    def selu(Node x, float a, float s):
-        return wrapNode(func_selu(x.wrapped, a, s))
-
     @staticmethod
     def sum(x, dim = None):
-        cdef vector[CppNode] xs
-        cdef Node node
         if isinstance(x, list):
-            for node in x:
-                xs.push_back(node.wrapped)
-            return wrapNode(func_sum(xs))
+            return functions.sum_list(x)
         else:
             return wrapNode(func_sum((<Node> x).wrapped, <unsigned> dim))
 
-    @staticmethod
-    def mean(x, dim = None):
-        cdef vector[CppNode] xs
-        cdef Node node
-        if isinstance(x, list):
-            for node in x:
-                xs.push_back(node.wrapped)
-            return wrapNode(func_mean(xs))
-        else:
-            return wrapNode(func_mean((<Node> x).wrapped, <unsigned> dim))
-
     @staticmethod
     def broadcast(Node x, unsigned dim, unsigned size):
         return wrapNode(func_broadcast(x.wrapped, dim, size))
@@ -243,43 +224,76 @@ class functions:
             get_cpp_device(device), get_cpp_graph(graph)))
 
     @staticmethod
-    def zeros(shape, Device device = None, Graph graph = None):
+    def identity(unsigned size, Device device = None, Graph graph = None):
         if device is None:
             device = Device.get_default()
         if graph is None:
             graph = Graph.get_default()
-        return wrapNode(func_zeros_node(normShape(shape).wrapped,
-                                        get_cpp_device(device), get_cpp_graph(graph)))
+        return wrapNode(func_identity_node(size, get_cpp_device(device), get_cpp_graph(graph)))
+
+    # contrib functions
 
     @staticmethod
-    def ones(shape, Device device = None, Graph graph = None):
-        if device is None:
-            device = Device.get_default()
-        if graph is None:
-            graph = Graph.get_default()
-        return wrapNode(func_ones_node(normShape(shape).wrapped,
-                                       get_cpp_device(device), get_cpp_graph(graph)))
+    def selu(Node x, float a=1.6732632423543772848170429916717, float s=1.0507009873554804934193349852946):
+        return s * functions.elu(x, a)
 
     @staticmethod
-    def identity(unsigned size, Device device = None, Graph graph = None):
-        if device is None:
-            device = Device.get_default()
-        if graph is None:
-            graph = Graph.get_default()
-        return wrapNode(func_identity_node(size, get_cpp_device(device), get_cpp_graph(graph)))
+    def sum_list(list xs):
+        if not xs:
+            raise TypeError("No nodes to sum.")
+        ret = xs[0]
+        for x in xs[1:]:
+            ret = ret + x
+        return ret
+
+    @staticmethod
+    def mean(x, dim = None):
+        if isinstance(x, list):
+            return functions.sum_list(x) / len(x)
+        else:
+            return functions.sum(x, dim) / x.shape()[dim]
+
+    @staticmethod
+    def zeros(shape, Device dev = None, Graph g = None):
+        return functions.constant(shape, 0.0, dev, g)
+
+    @staticmethod
+    def ones(shape, Device dev = None, Graph g = None):
+        return functions.constant(shape, 1.0, dev, g)
+
+    @staticmethod
+    def dropout(Node x, float rate, bool enabled):
+        if not enabled:
+            return x
+        if rate == 1.0:
+            return 0.0 * x
+        p = 1.0 - rate
+        return (1.0 / p) * x * functions.random.bernoulli(x.shape(), p, x.device())
+
+    # end contrib functions
 
     class batch:
         @staticmethod
         def sum(Node x):
             return wrapNode(func_batch_sum[CppNode](x.wrapped))
 
+        # contrib functions
+
         @staticmethod
         def mean(Node x):
-            return wrapNode(func_batch_mean[CppNode](x.wrapped))
+            return functions.batch.sum(x) / x.shape().batch()
 
         @staticmethod
         def normalize(Node x):
-            return wrapNode(func_batch_normalize[CppNode](x.wrapped))
+            if not x.shape().has_batch():
+                return x
+            b = x.shape().batch()
+            scale = b / (b - 1)
+            m = functions.batch.mean(x)
+            v = scale * (functions.batch.mean(x * x) - m * m)
+            return (x - m) / functions.sqrt(v + 1e-8)
+
+        # end contrib functions
 
     class random:
         @staticmethod
@@ -327,10 +341,6 @@ class functions:
         return wrapNode(func_random_gumbel_node(normShape(shape).wrapped, mu, beta,
                                                 get_cpp_device(device), get_cpp_graph(graph)))
 
-    @staticmethod
-    def dropout(Node x, float rate, bool enabled):
-        return wrapNode(func_dropout(x.wrapped, rate, enabled))
-
 
 class tensor_functions:
 
@@ -472,32 +482,13 @@ class tensor_functions:
     def elu(Tensor x, float a):
         return Tensor.get_wrapper_with_new(new CppTensor(func_elu(x.wrapped[0], a)))
 
-    @staticmethod
-    def selu(Tensor x, float a, float s):
-        return Tensor.get_wrapper_with_new(new CppTensor(func_selu(x.wrapped[0], a, s)))
-
     @staticmethod
     def sum(x, dim = None):
-        cdef vector[CppTensor] xs
-        cdef Tensor t
         if isinstance(x, list):
-            for t in x:
-                xs.push_back(t.wrapped[0])
-            return Tensor.get_wrapper_with_new(new CppTensor(func_sum(xs)))
+            return tensor_functions.sum_list(x)
         else:
             return Tensor.get_wrapper_with_new(new CppTensor(func_sum((<Tensor> x).wrapped[0], <unsigned> dim)))
 
-    @staticmethod
-    def mean(x, dim = None):
-        cdef vector[CppTensor] xs
-        cdef Tensor t
-        if isinstance(x, list):
-            for t in x:
-                xs.push_back(t.wrapped[0])
-            return Tensor.get_wrapper_with_new(new CppTensor(func_mean(xs)))
-        else:
-            return Tensor.get_wrapper_with_new(new CppTensor(func_mean((<Tensor> x).wrapped[0], <unsigned> dim)))
-
     @staticmethod
     def broadcast(Tensor x, unsigned dim, unsigned size):
         return Tensor.get_wrapper_with_new(new CppTensor(func_broadcast(x.wrapped[0], dim, size)))
@@ -554,38 +545,76 @@ class tensor_functions:
         return Tensor.get_wrapper_with_new(new CppTensor(func_constant_tensor(normShape(shape).wrapped, k,
                                                                               get_cpp_device(device))))
 
+
     @staticmethod
-    def zeros(shape, Device device = None):
+    def identity(unsigned size, Device device = None):
         if device is None:
             device = Device.get_default()
-        return Tensor.get_wrapper_with_new(new CppTensor(func_zeros_tensor(normShape(shape).wrapped,
-                                                                           get_cpp_device(device))))
+        return Tensor.get_wrapper_with_new(new CppTensor(func_identity_tensor(size, get_cpp_device(device))))
+
+    # contrib functions
 
     @staticmethod
-    def ones(shape, Device device = None):
-        if device is None:
-            device = Device.get_default()
-        return Tensor.get_wrapper_with_new(new CppTensor(func_ones_tensor(normShape(shape).wrapped,
-                                                                          get_cpp_device(device))))
+    def selu(Tensor x, float a=1.6732632423543772848170429916717, float s=1.0507009873554804934193349852946):
+        return s * tensor_functions.elu(x, a)
 
     @staticmethod
-    def identity(unsigned size, Device device = None):
-        if device is None:
-            device = Device.get_default()
-        return Tensor.get_wrapper_with_new(new CppTensor(func_identity_tensor(size, get_cpp_device(device))))
+    def sum_list(list xs):
+        if not xs:
+            raise TypeError("No nodes to sum.")
+        ret = xs[0]
+        for x in xs[1:]:
+            ret = ret + x
+        return ret
+
+    @staticmethod
+    def mean(x, dim = None):
+        if isinstance(x, list):
+            return tensor_functions.sum_list(x) / len(x)
+        else:
+            return tensor_functions.sum(x, dim) / x.shape()[dim]
+
+    @staticmethod
+    def zeros(shape, Device dev = None):
+        return tensor_functions.constant(shape, 0.0, dev)
+
+    @staticmethod
+    def ones(shape, Device dev = None):
+        return tensor_functions.constant(shape, 1.0, dev)
+
+    @staticmethod
+    def dropout(Tensor x, float rate, bool enabled):
+        if not enabled:
+            return x
+        if rate == 1.0:
+            return 0.0 * x
+        p = 1.0 - rate
+        return (1.0 / p) * x * tensor_functions.random.bernoulli(x.shape(), p, x.device())
+
+    # end contrib functions
 
     class batch:
         @staticmethod
         def sum(Tensor x):
             return Tensor.get_wrapper_with_new(new CppTensor(func_batch_sum[CppTensor](x.wrapped[0])))
 
+        # contrib functions
+
         @staticmethod
         def mean(Tensor x):
-            return Tensor.get_wrapper_with_new(new CppTensor(func_batch_mean[CppTensor](x.wrapped[0])))
+            return tensor_functions.batch.sum(x) / x.shape().batch()
 
         @staticmethod
         def normalize(Tensor x):
-            return Tensor.get_wrapper_with_new(new CppTensor(func_batch_normalize[CppTensor](x.wrapped[0])))
+            if not x.shape().has_batch():
+                return x
+            b = x.shape().batch()
+            scale = b / (b - 1)
+            m = tensor_functions.batch.mean(x)
+            v = scale * (tensor_functions.batch.mean(x * x) - m * m)
+            return (x - m) / tensor_functions.sqrt(v + 1e-8)
+
+        # end contrib functions
 
     class random:
         @staticmethod
@@ -622,7 +651,3 @@ class tensor_functions:
             device = Device.get_default()
         return Tensor.get_wrapper_with_new(new CppTensor(func_random_gumbel_tensor(normShape(shape).wrapped, mu, beta,
                                                                                    get_cpp_device(device))))
-
-    @staticmethod
-    def dropout(Tensor x, float rate, bool enabled):
-        return Tensor.get_wrapper_with_new(new CppTensor(func_dropout(x.wrapped[0], rate, enabled)))
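
Taken together, these .pyx changes keep the public surface (selu, sum/mean over lists, zeros, ones, dropout, batch.mean, batch.normalize) but back it with Python compositions of the remaining C++ primitives; batch.normalize additionally applies the b / (b - 1) factor so the per-batch variance estimate is unbiased before x - m is divided by sqrt(v + 1e-8). A minimal usage sketch follows; the import layout, the Naive device, and set_default follow primitiv's usual Python packaging and are assumptions here, not part of this diff:

    from primitiv import Device, Graph
    from primitiv import functions as F
    from primitiv import devices as D

    dev = D.Naive()                    # assumed CPU reference device
    Device.set_default(dev)
    g = Graph()
    Graph.set_default(g)

    x = F.input([2, 2], [1, 2, 3, 4])  # 2x2 input node
    y = F.selu(x)                      # contrib: s * elu(x, a)
    m = F.mean([x, y])                 # contrib: sum_list([x, y]) / len([x, y])
    d = F.dropout(m, 0.5, True)        # contrib: bernoulli mask + 1/(1-rate) scale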

primitiv/_graph.pxd

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@ from primitiv._shape cimport CppShape
 from primitiv._tensor cimport CppTensor
 
 
-cdef extern from "primitiv/graph.h" nogil:
+cdef extern from "primitiv/core/graph.h" nogil:
     cdef cppclass CppNode "primitiv::Node":
         CppNode(CppNode &&src) except +
         CppNode() except +
@@ -23,7 +23,7 @@ cdef extern from "primitiv/graph.h" nogil:
         void backward() except +
 
 
-cdef extern from "primitiv/graph.h" nogil:
+cdef extern from "primitiv/core/graph.h" nogil:
     cdef cppclass CppGraph "primitiv::Graph":
         CppGraph() except +
         void clear() except +

primitiv/_initializer.pxd

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from primitiv._tensor cimport CppTensor
 
 
-cdef extern from "primitiv/initializer.h":
+cdef extern from "primitiv/core/initializer.h":
     cdef cppclass CppInitializer "primitiv::Initializer":
         CppInitializer() except +
         void apply(CppTensor &x) except +

primitiv/_model.pxd

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ from primitiv._device cimport CppDevice
 from primitiv._parameter cimport CppParameter
 
 
-cdef extern from "primitiv/model.h":
+cdef extern from "primitiv/core/model.h":
     cdef cppclass CppModel "primitiv::Model":
         CppModel() except +
         void load(string &path, bool with_stats, CppDevice *device) except +

primitiv/_optimizer.pxd

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ from primitiv._parameter cimport CppParameter, Parameter
 from primitiv._shape cimport CppShape
 
 
-cdef extern from "primitiv/optimizer.h":
+cdef extern from "primitiv/core/optimizer.h":
     cdef cppclass CppOptimizer "primitiv::Optimizer":
         CppOptimizer(CppOptimizer &&) except +
         CppOptimizer() except +

primitiv/_parameter.pxd

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ from primitiv._device cimport CppDevice
 from primitiv._initializer cimport CppInitializer, Initializer
 
 
-cdef extern from "primitiv/parameter.h":
+cdef extern from "primitiv/core/parameter.h":
     cdef cppclass CppParameter "primitiv::Parameter":
         CppParameter() except +
         CppParameter(const CppShape &shape, const vector[float] &value, CppDevice *device) except +

primitiv/_shape.pxd

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ from libcpp.string cimport string
 from libcpp cimport bool
 
 
-cdef extern from "primitiv/shape.h":
+cdef extern from "primitiv/core/shape.h":
     cdef cppclass CppShape "primitiv::Shape":
         CppShape() except +
         CppShape(vector[unsigned] &dims, unsigned batch) except +
