Skip to content

Commit b054838

Browse files
yuchaoran2011awni
andauthored
Added clarification to apply_fn parameter of apply_to_modules (#2831)
Co-authored-by: Awni Hannun <[email protected]>
1 parent dd79d3c commit b054838

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

python/mlx/nn/layers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,10 @@ def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
407407
instance).
408408
409409
Args:
410-
apply_fn (Callable): The function to apply to the modules.
410+
apply_fn (Callable): The function to apply to the modules which
411+
takes two parameters. The first parameter is the string path of
412+
the module (e.g. ``"model.layers.0.linear"``). The second
413+
parameter is the module object.
411414
412415
Returns:
413416
The module instance after updating submodules.

python/src/transforms.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) {
12381238
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
12391239
12401240
Returns:
1241-
list(array): A list of the Jacobian-vector products which
1242-
is the same in number, shape, and type of the inputs to ``fun``.
1241+
tuple(list(array), list(array)): A tuple with the outputs of
1242+
``fun`` in the first position and the Jacobian-vector products
1243+
in the second position.
1244+
1245+
Example:
1246+
1247+
.. code-block:: python
1248+
1249+
import mlx.core as mx
1250+
1251+
outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
1252+
12431253
)pbdoc");
12441254
m.def(
12451255
"vjp",
@@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) {
12771287
same in number, shape, and type as the outputs of ``fun``.
12781288
12791289
Returns:
1280-
list(array): A list of the vector-Jacobian products which
1281-
is the same in number, shape, and type of the outputs of ``fun``.
1290+
tuple(list(array), list(array)): A tuple with the outputs of
1291+
``fun`` in the first position and the vector-Jacobian products
1292+
in the second position.
1293+
1294+
Example:
1295+
1296+
.. code-block:: python
1297+
1298+
import mlx.core as mx
1299+
1300+
outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
1301+
12821302
)pbdoc");
12831303
m.def(
12841304
"value_and_grad",

0 commit comments

Comments
 (0)