Skip to content

Commit 0e516cd

Browse files
committed
Added support for `*_like` variants and additional constructors (gaussian, uniform, ones, zeros) for GradTensor.
1 parent 0c9e1bd commit 0e516cd

File tree

6 files changed

+193
-3
lines changed

6 files changed

+193
-3
lines changed

aten/bindings/Tensor/gradtensor_binding.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,50 @@ void init_gradtensor_binding(py::module_ &m) {
1717
return new GradTensor(shape, bidx, pidx);
1818
}))
1919
.def_static("eye", &GradTensor::eye,
20-
py::arg("n"), py::arg("bidx"), py::arg("pidx"))
20+
py::arg("n"),
21+
py::arg("bidx"),
22+
py::arg("pidx")
23+
)
24+
.def_static("gaussian", &GradTensor::gaussian,
25+
py::arg("shape") = std::vector<size_t>{1, 1},
26+
py::arg("mean") = 0.0,
27+
py::arg("stddev") = 1.0,
28+
py::arg("bidx") = 0,
29+
py::arg("pidx") = 1
30+
)
31+
.def_static("gaussian_like", &GradTensor::gaussian_like,
32+
py::arg("input"),
33+
py::arg("mean") = 0.0,
34+
py::arg("stddev") = 1.0
35+
)
36+
.def_static("uniform", &GradTensor::uniform,
37+
py::arg("shape") = std::vector<size_t>{1, 1},
38+
py::arg("min") = 0.0,
39+
py::arg("max") = 1.0,
40+
py::arg("bidx") = 0,
41+
py::arg("pidx") = 1
42+
)
43+
.def_static("uniform_like", &GradTensor::uniform_like,
44+
py::arg("input"),
45+
py::arg("min") = 0.0,
46+
py::arg("max") = 1.0
47+
)
48+
.def_static("ones", &GradTensor::ones,
49+
py::arg("shape") = std::vector<size_t>{1, 1},
50+
py::arg("bidx") = 0,
51+
py::arg("pidx") = 1
52+
)
53+
.def_static("ones_like", &GradTensor::ones_like,
    // The C++ signature is ones_like(GradTensor* input); the keyword
    // argument was mis-named "shape", which would make
    // GradTensor.ones_like(input=t) fail from Python. Match the other
    // *_like bindings.
    py::arg("input")
)
56+
.def_static("zeros", &GradTensor::zeros,
57+
py::arg("shape") = std::vector<size_t>{1, 1},
58+
py::arg("bidx") = 0,
59+
py::arg("pidx") = 1
60+
)
61+
.def_static("zeros_like", &GradTensor::zeros_like,
    // The C++ signature is zeros_like(GradTensor* input); the keyword
    // argument was mis-named "shape". Match the other *_like bindings.
    py::arg("input")
)
2164

2265
// string
2366
.def("__repr__", &GradTensor::operator std::string, py::is_operator())

aten/bindings/Tensor/tensor_binding.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,39 @@ void init_tensor_binding(py::module_ &m) {
7171
py::arg("bidx") = 0,
7272
py::arg("requires_grad") = true
7373
)
74+
.def_static("gaussian_like", &Tensor::gaussian_like,
75+
py::arg("input"),
76+
py::arg("mean") = 0.0,
77+
py::arg("stddev") = 1.0
78+
)
7479
.def_static("uniform", &Tensor::uniform,
7580
py::arg("shape") = std::vector<size_t>{1},
7681
py::arg("min") = 0.0,
7782
py::arg("max") = 1.0,
7883
py::arg("bidx") = 0,
7984
py::arg("requires_grad") = true
8085
)
86+
.def_static("uniform_like", &Tensor::uniform_like,
87+
py::arg("input"),
88+
py::arg("min") = 0.0,
89+
py::arg("max") = 1.0
90+
)
8191
.def_static("ones", &Tensor::ones,
8292
py::arg("shape") = std::vector<size_t>{1},
8393
py::arg("bidx") = 0,
8494
py::arg("requires_grad") = true
8595
)
96+
.def_static("ones_like", &Tensor::ones_like,
97+
py::arg("input")
98+
)
8699
.def_static("zeros", &Tensor::zeros,
87100
py::arg("shape") = std::vector<size_t>{1},
88101
py::arg("bidx") = 0,
89102
py::arg("requires_grad") = true
90103
)
104+
.def_static("zeros_like", &Tensor::zeros_like,
105+
py::arg("input")
106+
)
91107

92108
// string
93109
.def("__repr__", &Tensor::operator std::string, py::is_operator())

aten/src/Tensor/GradTensor/constructor.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include <atomic>
#include <chrono>
#include <ctime>
#include <functional>
#include <numeric>
#include <random>
#include <vector>

#include "../Tensor.h"
#include "../../Util/utils.h"
59

@@ -48,3 +52,72 @@ GradTensor* GradTensor::eye(size_t n, size_t bidx, size_t pidx) {
4852
return new GradTensor(storage, shape, bidx, pidx);
4953
}
5054

55+
GradTensor* GradTensor::gaussian(std::vector<size_t> shape, double mean, double stddev, size_t bidx, size_t pidx) {
  // Build a per-call unique seed by mixing the high-resolution clock with a
  // monotonically increasing counter, so two calls within the same clock
  // tick still get distinct seeds.
  static std::atomic<unsigned long long> seed_counter{0};

  auto now = std::chrono::high_resolution_clock::now();
  auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(now.time_since_epoch()).count();
  unsigned long long unique_seed =
      static_cast<unsigned long long>(nanos) ^
      (seed_counter.fetch_add(1, std::memory_order_relaxed) << 32);

  std::mt19937 generator(unique_seed);
  std::normal_distribution<double> distribution(mean, stddev);

  // Total element count. Accumulate in size_t: the previous int accumulator
  // could overflow for large shapes and mixed signedness with the size_t
  // shape entries.
  size_t length = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());

  // Fill a flat buffer with N(mean, stddev) samples.
  std::vector<double> result(length);
  for (size_t i = 0; i < length; ++i) {
    result[i] = distribution(generator);
  }

  return new GradTensor(result, shape, bidx, pidx);
}
80+
81+
GradTensor* GradTensor::gaussian_like(GradTensor* input, double mean, double stddev) {
  // Fresh Gaussian tensor whose shape and index metadata mirror `input`.
  // NOTE(review): `bidx` is read as a field while `pidx()` is an accessor —
  // assumed intentional given the class layout; confirm against Tensor.h.
  const auto like_shape = input->shape();
  return GradTensor::gaussian(like_shape, mean, stddev, input->bidx, input->pidx());
}
84+
85+
GradTensor* GradTensor::uniform(std::vector<size_t> shape, double min, double max, size_t bidx, size_t pidx) {
  // Same unique-seeding scheme as GradTensor::gaussian: mix the
  // high-resolution clock with an atomic counter so back-to-back calls
  // never share a seed.
  static std::atomic<unsigned long long> seed_counter{0};

  auto now = std::chrono::high_resolution_clock::now();
  auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(now.time_since_epoch()).count();
  unsigned long long unique_seed =
      static_cast<unsigned long long>(nanos) ^
      (seed_counter.fetch_add(1, std::memory_order_relaxed) << 32);

  std::mt19937 generator(unique_seed);
  // uniform_real_distribution samples from the half-open range [min, max).
  std::uniform_real_distribution<double> distribution(min, max);

  // Element count in size_t: the previous int accumulator could overflow
  // for large shapes and mixed signedness with the size_t shape entries.
  size_t length = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());

  std::vector<double> result(length);
  for (size_t i = 0; i < length; ++i) {
    result[i] = distribution(generator);
  }

  return new GradTensor(result, shape, bidx, pidx);
}
104+
105+
GradTensor* GradTensor::uniform_like(GradTensor* input, double min, double max) {
  // Fresh uniform tensor whose shape and index metadata mirror `input`.
  const auto like_shape = input->shape();
  return GradTensor::uniform(like_shape, min, max, input->bidx, input->pidx());
}
108+
109+
GradTensor* GradTensor::ones(std::vector<size_t> shape, size_t bidx, size_t pidx) {
  // All-ones tensor of the given shape; CIntegrity::prod gives the flat
  // element count.
  std::vector<double> fill(CIntegrity::prod(shape), 1.0);
  return new GradTensor(fill, shape, bidx, pidx);
}
112+
113+
GradTensor* GradTensor::ones_like(GradTensor* input) {
  // Shape and index metadata copied from `input`; contents set to 1.0.
  const auto like_shape = input->shape();
  return GradTensor::ones(like_shape, input->bidx, input->pidx());
}
116+
117+
GradTensor* GradTensor::zeros(std::vector<size_t> shape, size_t bidx, size_t pidx) {
  // All-zeros tensor of the given shape.
  std::vector<double> fill(CIntegrity::prod(shape), 0.0);
  return new GradTensor(fill, shape, bidx, pidx);
}
120+
121+
GradTensor* GradTensor::zeros_like(GradTensor* input) {
  // Shape and index metadata copied from `input`; contents set to 0.0.
  const auto like_shape = input->shape();
  return GradTensor::zeros(like_shape, input->bidx, input->pidx());
}

aten/src/Tensor/Tensor.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,15 @@ class GradTensor : public BaseTensor {
9494
GradTensor(std::vector<double> storage, std::vector<size_t> shape, size_t bidx, size_t pidx);
9595
GradTensor(std::vector<size_t> shape, size_t bidx, size_t pidx);
9696
static GradTensor* eye(size_t n, size_t bidx, size_t pidx);
97-
// zeros, ones, zeros_like, ones_like, uninitialized (requires not vector but array), random ones
97+
// uninitialized (requires not vector but array)
98+
static GradTensor* gaussian(std::vector<size_t> shape, double mean, double stddev, size_t bidx, size_t pidx);
99+
static GradTensor* gaussian_like(GradTensor* input, double mean, double stddev);
100+
static GradTensor* uniform(std::vector<size_t> shape, double min, double max, size_t bidx, size_t pidx);
101+
static GradTensor* uniform_like(GradTensor* input, double min, double max);
102+
static GradTensor* ones(std::vector<size_t> shape, size_t bidx, size_t pidx);
103+
static GradTensor* ones_like(GradTensor* input);
104+
static GradTensor* zeros(std::vector<size_t> shape, size_t bidx, size_t pidx);
105+
static GradTensor* zeros_like(GradTensor* input);
98106

99107
// string.cpp
100108
operator std::string() const override;
@@ -168,9 +176,13 @@ class Tensor : public BaseTensor {
168176
static Tensor* arange(int start, int stop, int step = 1, bool requires_grad = true);
169177
static Tensor* linspace(double start, double stop, int numsteps, bool requires_grad = true);
170178
static Tensor* gaussian(std::vector<size_t> shape, double mean = 0.0, double stddev = 1.0, size_t bidx = 0, bool requires_grad = true);
179+
static Tensor* gaussian_like(Tensor* input, double mean, double stddev);
171180
static Tensor* uniform(std::vector<size_t> shape, double min = 0.0, double max = 1.0, size_t bidx = 0, bool requires_grad = true);
181+
static Tensor* uniform_like(Tensor* input, double min, double max);
172182
static Tensor* ones(std::vector<size_t> shape, size_t bidx = 0, bool requires_grad = true);
183+
static Tensor* ones_like(Tensor* input);
173184
static Tensor* zeros(std::vector<size_t> shape, size_t bidx = 0, bool requires_grad = true);
185+
static Tensor* zeros_like(Tensor* input);
174186
~Tensor() { _prev.clear(); }
175187

176188
// string.cpp

aten/src/Tensor/Tensor/constructor.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "../../Util/utils.h"
99

1010
Tensor::Tensor(double scalar, bool requires_grad) {
11-
// Scalar tensor
1211
this->_storage = std::vector<double>{scalar};
1312
this->_shape = std::vector<size_t>{1};
1413
this->bidx = 0;
@@ -126,6 +125,10 @@ Tensor* Tensor::gaussian(std::vector<size_t> shape, double mean, double stddev,
126125
return new Tensor(result, shape, bidx, requires_grad);
127126
}
128127

128+
Tensor* Tensor::gaussian_like(Tensor* input, double mean, double stddev) {
  // Fresh N(mean, stddev) samples; shape, batch index, and grad flag are
  // mirrored from `input`.
  const auto like_shape = input->shape();
  return Tensor::gaussian(like_shape, mean, stddev, input->bidx, input->requires_grad);
}
131+
129132
Tensor* Tensor::uniform(std::vector<size_t> shape, double min, double max, size_t bidx, bool requires_grad) {
130133
// (Use the same unique seeding method as in the gaussian function)
131134
static std::atomic<unsigned long long> seed_counter{0};
@@ -146,11 +149,23 @@ Tensor* Tensor::uniform(std::vector<size_t> shape, double min, double max, size_
146149
return new Tensor(result, shape, bidx, requires_grad);
147150
}
148151

152+
Tensor* Tensor::uniform_like(Tensor* input, double min, double max) {
  // Fresh uniform samples; shape, batch index, and grad flag are mirrored
  // from `input`.
  const auto like_shape = input->shape();
  return Tensor::uniform(like_shape, min, max, input->bidx, input->requires_grad);
}
155+
149156
Tensor* Tensor::ones(std::vector<size_t> shape, size_t bidx, bool requires_grad) {
  // All-ones tensor of the given shape.
  std::vector<double> fill(CIntegrity::prod(shape), 1.0);
  return new Tensor(fill, shape, bidx, requires_grad);
}
152159

160+
Tensor* Tensor::ones_like(Tensor* input) {
  // Shape, batch index, and grad flag copied from `input`; contents 1.0.
  const auto like_shape = input->shape();
  return Tensor::ones(like_shape, input->bidx, input->requires_grad);
}
163+
153164
Tensor* Tensor::zeros(std::vector<size_t> shape, size_t bidx, bool requires_grad) {
  // All-zeros tensor of the given shape.
  std::vector<double> fill(CIntegrity::prod(shape), 0.0);
  return new Tensor(fill, shape, bidx, requires_grad);
}
156167

168+
Tensor* Tensor::zeros_like(Tensor* input) {
  // Shape, batch index, and grad flag copied from `input`; contents 0.0.
  const auto like_shape = input->shape();
  return Tensor::zeros(like_shape, input->bidx, input->requires_grad);
}
171+

ember/aten/__init__.pyi

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,37 @@ class GradTensor(BaseTensor):
5151
def __init__(self, shape: List[int], bidx: int, pidx: int) -> None: ...
5252
@staticmethod
5353
def eye(n: int, pidx: int = 1) -> 'GradTensor': ...
54+
@staticmethod
55+
def gaussian(
56+
shape: List[int] = [1, 1],
57+
mean: float = 0.0,
58+
stddev: float = 1.0,
59+
bidx: int = 0,
60+
pidx: int = 1
61+
) -> 'GradTensor': ...
62+
63+
@staticmethod
64+
def uniform(
65+
shape: List[int] = [1, 1],
66+
min: float = 0.0,
67+
max: float = 1.0,
68+
bidx: int = 0,
69+
pidx: int = 1
70+
) -> 'GradTensor': ...
71+
72+
@staticmethod
73+
def ones(
74+
shape: List[int] = [1, 1],
75+
bidx: int = 0,
76+
pidx: int = 1
77+
) -> 'GradTensor': ...
78+
79+
@staticmethod
80+
def zeros(
81+
shape: List[int] = [1, 1],
82+
bidx: int = 0,
83+
pidx: int = 1
84+
) -> 'GradTensor': ...
5485

5586
# string
5687
def __str__(self) -> str: ...

0 commit comments

Comments
 (0)