Commit 243836e

DOC: Fix docstrings in gradient.py (#415)
1 parent: 49acbc5


pytensor/gradient.py

Lines changed: 9 additions & 11 deletions
@@ -196,12 +196,13 @@ def Rop(
 
     Returns
     -------
+    :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
         A symbolic expression such obeying
         ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
         where the indices in that expression are magic multidimensional
         indices that specify both the position within a list and all
         coordinates of the tensor elements.
-        If `wrt` is a list/tuple, then return a list/tuple with the results.
+        If `f` is a list/tuple, then return a list/tuple with the results.
     """
 
     if not isinstance(wrt, (list, tuple)):
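
Note: the contract documented above makes `Rop` a Jacobian-vector product. A minimal usage sketch, illustrative only and not part of this commit (the variable names are made up):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Rop

    # f(x) = x**2 has Jacobian diag(2 * x), so the R-operator is 2 * x * v
    x = pt.vector("x")
    v = pt.vector("v")  # plays the role of `eval_points`
    f = x**2

    jvp = Rop(f, x, v)  # R_op[i] = sum_j (d f[i] / d wrt[j]) v[j]
    fn = pytensor.function([x, v], jvp)
    print(fn([1.0, 2.0], [1.0, 1.0]))  # [2. 4.]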
@@ -384,6 +385,7 @@ def Lop(
 
     Returns
     -------
+    :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
         A symbolic expression satisfying
         ``L_op[i] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
         where the indices in that expression are magic multidimensional
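
`Lop` is the adjoint counterpart: a vector-Jacobian product, which is what backpropagation evaluates. Again an illustrative sketch, not part of the diff:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Lop

    x = pt.vector("x")
    u = pt.vector("u")  # plays the role of `eval_points`
    f = x**2

    vjp = Lop(f, x, u)  # L_op[j] = sum_i u[i] (d f[i] / d wrt[j])
    fn = pytensor.function([x, u], vjp)
    print(fn([1.0, 2.0], [3.0, 3.0]))  # [ 6. 12.]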
@@ -481,10 +483,10 @@ def grad(
 
     Returns
     -------
+    :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
         A symbolic expression for the gradient of `cost` with respect to each
         of the `wrt` terms. If an element of `wrt` is not differentiable with
         respect to the output, then a zero variable is returned.
-
     """
     t0 = time.perf_counter()
 
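
For `grad` itself, the documented return covers the usual scalar-cost case; a hypothetical sketch:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import grad

    x = pt.vector("x")
    cost = (x**2).sum()  # `cost` must be a scalar

    g = grad(cost, wrt=x)  # symbolic gradient, here 2 * x
    fn = pytensor.function([x], g)
    print(fn([1.0, 2.0]))  # [2. 4.]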
@@ -701,7 +703,6 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
 
     Parameters
     ----------
-
     wrt : list of variables
         Gradients are computed with respect to `wrt`.
 
@@ -876,7 +877,6 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
 
     (A variable in consider_constant is not a function of
     anything)
-
     """
 
     # Validate and format consider_constant
@@ -1035,7 +1035,6 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
     -------
     list of Variables
         A list of gradients corresponding to `wrt`
-
     """
     # build a dict mapping node to the terms node contributes to each of
     # its inputs' gradients
@@ -1423,8 +1422,9 @@ def access_grad_cache(var):
 
 
 def _float_zeros_like(x):
-    """Like zeros_like, but forces the object to have a
-    a floating point dtype"""
+    """Like zeros_like, but forces the object to have
+    a floating point dtype
+    """
 
     rval = x.zeros_like()
 
@@ -1436,7 +1436,8 @@ def _float_zeros_like(x):
 
 def _float_ones_like(x):
     """Like ones_like, but forces the object to have a
-    floating point dtype"""
+    floating point dtype
+    """
 
     dtype = x.type.dtype
     if dtype not in pytensor.tensor.type.float_dtypes:
@@ -1613,7 +1614,6 @@ def abs_rel_errors(self, g_pt):
 
         Corresponding ndarrays in `g_pt` and `self.gf` must have the same
         shape or ValueError is raised.
-
         """
         if len(g_pt) != len(self.gf):
             raise ValueError("argument has wrong number of elements", len(g_pt))
@@ -1740,7 +1740,6 @@ def verify_grad(
     This function does not support multiple outputs. In `tests.scan.test_basic`
     there is an experimental `verify_grad` that covers that case as well by
     using random projections.
-
     """
     from pytensor.compile.function import function
     from pytensor.compile.sharedvalue import shared
@@ -2267,7 +2266,6 @@ def grad_clip(x, lower_bound, upper_bound):
     -----
     We register an opt in tensor/opt.py that remove the GradClip.
     So it have 0 cost in the forward and only do work in the grad.
-
     """
     return GradClip(lower_bound, upper_bound)(x)
 
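The Notes section above says the clip is removed from the forward graph and acts only on the gradient; a sketch of that behavior (illustrative, assuming only the public `grad_clip`/`grad` API shown in this diff):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import grad, grad_clip

    x = pt.scalar("x")
    y = grad_clip(x, -1.0, 1.0) ** 2  # forward value is exactly x**2

    g = grad(y, x)  # the gradient flowing through the clip is limited to [-1, 1]
    fn = pytensor.function([x], [y, g])
    print(fn(2.0))  # [array(4.), array(1.)] -- unclipped gradient would be 4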