File tree Expand file tree Collapse file tree 2 files changed +43
-11
lines changed
Expand file tree Collapse file tree 2 files changed +43
-11
lines changed Original file line number Diff line number Diff line change @@ -51,15 +51,15 @@ void init_gradtensor_binding(py::module_ &m) {
5151 py::arg (" pidx" ) = 1
5252 )
5353 .def_static (" ones_like" , &GradTensor::ones_like,
54- py::arg (" shape " )
54+ py::arg (" input " )
5555 )
5656 .def_static (" zeros" , &GradTensor::zeros,
5757 py::arg (" shape" ) = std::vector<size_t >{1 , 1 },
5858 py::arg (" bidx" ) = 0 ,
5959 py::arg (" pidx" ) = 1
6060 )
6161 .def_static (" zeros_like" , &GradTensor::zeros_like,
62- py::arg (" shape " )
62+ py::arg (" input " )
6363 )
6464
6565 // string
Original file line number Diff line number Diff line change @@ -59,7 +59,12 @@ class GradTensor(BaseTensor):
5959 bidx : int = 0 ,
6060 pidx : int = 1
6161 ) -> 'GradTensor' : ...
62-
62+ @staticmethod
63+ def gaussian_like (
64+ input : 'GradTensor' ,
65+ mean : float = 0.0 ,
66+ stddev : float = 1.0
67+ ) -> 'GradTensor' : ...
6368 @staticmethod
6469 def uniform (
6570 shape : List [int ] = [1 , 1 ],
@@ -68,20 +73,32 @@ class GradTensor(BaseTensor):
6873 bidx : int = 0 ,
6974 pidx : int = 1
7075 ) -> 'GradTensor' : ...
71-
76+ @staticmethod
77+ def uniform_like (
78+ input : 'GradTensor' ,
79+ min : float = 0.0 ,
80+ max : float = 1.0
81+ ) -> 'GradTensor' : ...
7282 @staticmethod
7383 def ones (
7484 shape : List [int ] = [1 , 1 ],
7585 bidx : int = 0 ,
7686 pidx : int = 1
7787 ) -> 'GradTensor' : ...
78-
88+ @staticmethod
89+ def ones_like (
90+ input : 'GradTensor'
91+ ) -> 'GradTensor' : ...
7992 @staticmethod
8093 def zeros (
8194 shape : List [int ] = [1 , 1 ],
8295 bidx : int = 0 ,
8396 pidx : int = 1
8497 ) -> 'GradTensor' : ...
98+ @staticmethod
99+ def zeros_like (
100+ input : 'GradTensor'
101+ ) -> 'GradTensor' : ...
85102
86103 # string
87104 def __str__ (self ) -> str : ...
@@ -210,23 +227,21 @@ class Tensor(BaseTensor):
210227 bidx : int = 0 ,
211228 requires_grad : bool = True
212229 ) -> None : ...
213-
230+
214231 @staticmethod
215232 def arange (
216233 start : int ,
217234 stop : int ,
218235 step : int = 1 ,
219236 requires_grad : bool = True
220237 ) -> 'Tensor' : ...
221-
222238 @staticmethod
223239 def linspace (
224240 start : float ,
225241 stop : float ,
226242 numsteps : int ,
227243 requires_grad : bool = True
228244 ) -> 'Tensor' : ...
229-
230245 @staticmethod
231246 def gaussian (
232247 shape : List [int ],
@@ -235,7 +250,12 @@ class Tensor(BaseTensor):
235250 bidx : int = 0 ,
236251 requires_grad : bool = True
237252 ) -> 'Tensor' : ...
238-
253+ @staticmethod
254+ def gaussian_like (
255+ input : 'Tensor' ,
256+ mean : float = 0.0 ,
257+ stddev : float = 1.0
258+ ) -> 'Tensor' : ...
239259 @staticmethod
240260 def uniform (
241261 shape : List [int ],
@@ -244,20 +264,32 @@ class Tensor(BaseTensor):
244264 bidx : int = 0 ,
245265 requires_grad : bool = True
246266 ) -> 'Tensor' : ...
247-
267+ @staticmethod
268+ def uniform_like (
269+ input : 'Tensor' ,
270+ min : float = 0.0 ,
271+ max : float = 1.0
272+ ) -> 'Tensor' : ...
248273 @staticmethod
249274 def ones (
250275 shape : List [int ],
251276 bidx : int = 0 ,
252277 requires_grad : bool = True
253278 ) -> 'Tensor' : ...
254-
279+ @staticmethod
280+ def ones_like (
281+ input : 'Tensor'
282+ ) -> 'Tensor' : ...
255283 @staticmethod
256284 def zeros (
257285 shape : List [int ],
258286 bidx : int = 0 ,
259287 requires_grad : bool = True
260288 ) -> 'Tensor' : ...
289+ @staticmethod
290+ def zeros_like (
291+ input : 'Tensor'
292+ ) -> 'Tensor' : ...
261293
262294 # string
263295 def __str__ (self ) -> str : ...
You can’t perform that action at this time.
0 commit comments