Skip to content

Commit 6cf0ea3

Browse files
committed
added stub files and bindings
1 parent 0e516cd commit 6cf0ea3

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

aten/bindings/Tensor/gradtensor_binding.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ void init_gradtensor_binding(py::module_ &m) {
5151
py::arg("pidx") = 1
5252
)
5353
.def_static("ones_like", &GradTensor::ones_like,
54-
py::arg("shape")
54+
py::arg("input")
5555
)
5656
.def_static("zeros", &GradTensor::zeros,
5757
py::arg("shape") = std::vector<size_t>{1, 1},
5858
py::arg("bidx") = 0,
5959
py::arg("pidx") = 1
6060
)
6161
.def_static("zeros_like", &GradTensor::zeros_like,
62-
py::arg("shape")
62+
py::arg("input")
6363
)
6464

6565
// string

ember/aten/__init__.pyi

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ class GradTensor(BaseTensor):
5959
bidx: int = 0,
6060
pidx: int = 1
6161
) -> 'GradTensor': ...
62-
62+
@staticmethod
63+
def gaussian_like(
64+
input: 'GradTensor',
65+
mean: float = 0.0,
66+
stddev: float = 1.0
67+
) -> 'GradTensor': ...
6368
@staticmethod
6469
def uniform(
6570
shape: List[int] = [1, 1],
@@ -68,20 +73,32 @@ class GradTensor(BaseTensor):
6873
bidx: int = 0,
6974
pidx: int = 1
7075
) -> 'GradTensor': ...
71-
76+
@staticmethod
77+
def uniform_like(
78+
input: 'GradTensor',
79+
min: float = 0.0,
80+
max: float = 1.0
81+
) -> 'GradTensor': ...
7282
@staticmethod
7383
def ones(
7484
shape: List[int] = [1, 1],
7585
bidx: int = 0,
7686
pidx: int = 1
7787
) -> 'GradTensor': ...
78-
88+
@staticmethod
89+
def ones_like(
90+
input: 'GradTensor'
91+
) -> 'GradTensor': ...
7992
@staticmethod
8093
def zeros(
8194
shape: List[int] = [1, 1],
8295
bidx: int = 0,
8396
pidx: int = 1
8497
) -> 'GradTensor': ...
98+
@staticmethod
99+
def zeros_like(
100+
input: 'GradTensor'
101+
) -> 'GradTensor': ...
85102

86103
# string
87104
def __str__(self) -> str: ...
@@ -210,23 +227,21 @@ class Tensor(BaseTensor):
210227
bidx: int = 0,
211228
requires_grad: bool = True
212229
) -> None: ...
213-
230+
214231
@staticmethod
215232
def arange(
216233
start: int,
217234
stop: int,
218235
step: int = 1,
219236
requires_grad: bool = True
220237
) -> 'Tensor': ...
221-
222238
@staticmethod
223239
def linspace(
224240
start: float,
225241
stop: float,
226242
numsteps: int,
227243
requires_grad: bool = True
228244
) -> 'Tensor': ...
229-
230245
@staticmethod
231246
def gaussian(
232247
shape: List[int],
@@ -235,7 +250,12 @@ class Tensor(BaseTensor):
235250
bidx: int = 0,
236251
requires_grad: bool = True
237252
) -> 'Tensor': ...
238-
253+
@staticmethod
254+
def gaussian_like(
255+
input: 'Tensor',
256+
mean: float = 0.0,
257+
stddev: float = 1.0
258+
) -> 'Tensor': ...
239259
@staticmethod
240260
def uniform(
241261
shape: List[int],
@@ -244,20 +264,32 @@ class Tensor(BaseTensor):
244264
bidx: int = 0,
245265
requires_grad: bool = True
246266
) -> 'Tensor': ...
247-
267+
@staticmethod
268+
def uniform_like(
269+
input: 'Tensor',
270+
min: float = 0.0,
271+
max: float = 1.0
272+
) -> 'Tensor': ...
248273
@staticmethod
249274
def ones(
250275
shape: List[int],
251276
bidx: int = 0,
252277
requires_grad: bool = True
253278
) -> 'Tensor': ...
254-
279+
@staticmethod
280+
def ones_like(
281+
input: 'Tensor'
282+
) -> 'Tensor': ...
255283
@staticmethod
256284
def zeros(
257285
shape: List[int],
258286
bidx: int = 0,
259287
requires_grad: bool = True
260288
) -> 'Tensor': ...
289+
@staticmethod
290+
def zeros_like(
291+
input: 'Tensor'
292+
) -> 'Tensor': ...
261293

262294
# string
263295
def __str__(self) -> str: ...

0 commit comments

Comments
 (0)