Skip to content

Commit 69559f2

Browse files
author
Alexander
committed
renamed module to mpx
1 parent b18d7f4 commit 69559f2

File tree

9 files changed

+8
-8
lines changed

9 files changed

+8
-8
lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333

3434
import optax
3535

36-
import mpfj.cast as cast
37-
import mpfj.loss_scaling as loss_scaling
36+
import mpx.cast as cast
37+
import mpx.loss_scaling as loss_scaling
3838

3939
from jaxtyping import PyTree, Bool
4040

tests/test_cast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jaxtyping import Array, Float, Int, PyTree
66
import numpy as np
77

8-
from mpfj.cast import (
8+
from mpx.cast import (
99
cast_tree,
1010
cast_to_float32,
1111
cast_to_float16,
@@ -14,7 +14,7 @@
1414
cast_to_half_precision,
1515
force_full_precision,
1616
)
17-
from mpfj.dtypes import HALF_PRECISION_DATATYPE
17+
from mpx.dtypes import HALF_PRECISION_DATATYPE
1818

1919

2020
class EQXModuleBase(eqx.Module):

tests/test_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import jax.numpy as jnp
3-
from mpfj.dtypes import half_precision_datatype, set_half_precision_datatype
3+
from mpx.dtypes import half_precision_datatype, set_half_precision_datatype
44

55
class TestDtypes(unittest.TestCase):
66
def test_default_half_precision(self):

tests/test_grad_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import jax.numpy as jnp
44
import equinox as eqx
55
import optax
6-
from mpfj.grad_tools import select_tree, filter_grad, filter_value_and_grad, optimizer_update
7-
from mpfj.loss_scaling import DynamicLossScaling
6+
from mpx.grad_tools import select_tree, filter_grad, filter_value_and_grad, optimizer_update
7+
from mpx.loss_scaling import DynamicLossScaling
88

99

1010
# Create a simple model for testing

tests/test_loss_scaling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import jax
33
import jax.numpy as jnp
44
import equinox as eqx
5-
from mpfj.loss_scaling import DynamicLossScaling, all_finite, scaled
5+
from mpx.loss_scaling import DynamicLossScaling, all_finite, scaled
66

77
class TestLossScaling(unittest.TestCase):
88
def setUp(self):

0 commit comments

Comments
 (0)