Skip to content

Commit 307e88f

Browse files
committed
Fix typos: Change 'arugments' to 'arguments'.
1 parent 0a755ae commit 307e88f

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

docs/Custom_Operation_for_GPUs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ class RmsNormFwdClass:
679679
NamedSharding(mesh, PartitionSpec(None, None)))
680680
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
681681
output_shardings = (arg_shardings[0], invvar_sharding)
682-
# Sharded_impl only accepts positional arugments
682+
# Sharded_impl only accepts positional arguments
683683
# And they should be Jax traceable variables
684684
impl = partial(RmsNormFwdClass.impl, eps=eps)
685685
@@ -739,7 +739,7 @@ class RmsNormBwdClass:
739739
output_shardings = (output_sharding, invvar_sharding, invvar_sharding)
740740
741741
742-
# Sharded_impl only accepts positional arugments
742+
# Sharded_impl only accepts positional arguments
743743
# And they should be Jax traceable variables
744744
def impl(g, invvar, x, weight):
745745
grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(

docs/Custom_Operation_for_GPUs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def partition(eps: float, mesh : jax.sharding.Mesh,
353353
NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything.
354354
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
355355
output_shardings = (arg_shardings[0], invvar_sharding)
356-
# Sharded_impl only accepts positional arugments
356+
# Sharded_impl only accepts positional arguments
357357
# And they should be Jax traceable variables
358358
impl = partial(RmsNormFwdClass.impl, eps=eps)
359359

jax/_src/callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def pure_callback(
343343
* Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method``
344344
is deprecated and it will eventually raise ``NotImplementedError``.
345345
* ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
346-
the batched arugments, calling ``callback`` once for each batch element.
346+
the batched arguments, calling ``callback`` once for each batch element.
347347
* ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
348348
added as the leading dimension unbatched inputs.
349349
* ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the

0 commit comments

Comments
 (0)