Skip to content

Commit a8cbddd

Browse files
QwlouseThe kauldron Authors
authored andcommitted
migrate kauldron to ktyping
PiperOrigin-RevId: 822524116
1 parent b8970f7 commit a8cbddd

30 files changed

+71
-58
lines changed

kauldron/data/deprecated.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
from kauldron.data import pipelines
3838
from kauldron.data import utils
3939
from kauldron.data.loaders import base as base_data_loader
40-
from kauldron.typing import PRNGKeyLike, PyTree # pylint: disable=g-importing-member,g-multiple-import
40+
from kauldron.ktyping import PyTree # pylint: disable=g-importing-member
41+
from kauldron.typing import PRNGKeyLike # pylint: disable=g-importing-member
4142
from kauldron.utils import config_util
4243
import tensorflow as tf
4344
import tensorflow_datasets as tfds

kauldron/data/pipelines.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from kauldron.data import data_utils
2727
from kauldron.data import iterators
2828
from kauldron.data import utils
29-
from kauldron.typing import PRNGKeyLike, PyTree # pylint: disable=g-importing-member,g-multiple-import
29+
from kauldron.ktyping import PyTree # pylint: disable=g-importing-member
30+
from kauldron.typing import PRNGKeyLike # pylint: disable=g-importing-member
3031
from kauldron.utils import config_util
3132

3233
# Output of `tfds.as_numpy`

kauldron/evals/eval_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
from kauldron.checkpoints import partial_loader
3333
from kauldron.evals import evaluators as evaluators_lib
3434
from kauldron.evals import run_strategies
35+
from kauldron.ktyping import PyTree # pylint: disable=g-importing-member
3536
from kauldron.train import auxiliaries
3637
from kauldron.train import train_step
3738
from kauldron.train import trainer_lib
38-
from kauldron.typing import PyTree # pylint: disable=g-importing-member
3939
from kauldron.utils import constants
4040
from kauldron.utils.status_utils import status # pylint: disable=g-importing-member
4141
import orbax.checkpoint as ocp

kauldron/evals/fewshot_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
from kauldron import data
2828
from kauldron import kontext
2929
from kauldron.evals import evaluators
30+
from kauldron.ktyping import Array, Float, Int, Scalar, check_type, typechecked # pylint: disable=g-multiple-import,g-importing-member
3031
from kauldron.metrics import base
3132
from kauldron.metrics import base_state
3233
from kauldron.train import auxiliaries
3334
from kauldron.train import train_step
34-
from kauldron.typing import Array, Float, Int, Scalar, check_type, typechecked # pylint: disable=g-multiple-import,g-importing-member
3535
from kauldron.utils import config_util
3636
from kauldron.utils import kdash
3737
from kauldron.utils import utils

kauldron/kd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from kauldron import train
4343
from kauldron import testing
4444
from kauldron import typing
45+
from kauldron import ktyping
4546
from kauldron.utils import api as utils
4647
from kauldron.utils import from_xid
4748
from kauldron.utils import kdash

kauldron/losses/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from jax import numpy as jnp
2626
from kauldron import kontext
2727
from kauldron import metrics
28+
from kauldron.ktyping import Array, Float, PyTree # pylint: disable=g-multiple-import,g-importing-member
2829
from kauldron.metrics import base_state
29-
from kauldron.typing import Array, Float, PyTree # pylint: disable=g-multiple-import,g-importing-member
3030

3131

3232
Schedule = Callable[[int], float]

kauldron/losses/simple.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
import jax.numpy as jnp
2222
from kauldron import kontext
23+
from kauldron.ktyping import Array, Float, Int, typechecked # pylint: disable=g-multiple-import,g-importing-member
2324
from kauldron.losses import base
24-
from kauldron.typing import Array, Float, Int, typechecked # pylint: disable=g-multiple-import,g-importing-member
2525
import optax
2626

2727

@@ -80,6 +80,7 @@ def get_values(self, preds: Float["*a"], targets: Float["*a"]) -> Float["*a"]:
8080
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
8181
class Huber(base.Loss):
8282
"""Huber loss."""
83+
8384
delta: float = 1.0
8485

8586
preds: kontext.Key = kontext.REQUIRED
@@ -113,7 +114,8 @@ def get_values(
113114
targets = self._safe_normalize(targets)
114115

115116
similarity = jnp.sum(preds * targets, axis=-1, keepdims=True)
116-
return - similarity
117+
return -similarity
118+
117119

118120
# ============================== Classification ===============================
119121

kauldron/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
@_catch_post_mortem
6868
def main(_):
6969
tf.config.set_visible_devices([], "GPU")
70-
kd.typing.enable_kd_type_checking() # Enable custom checks before resolve
7170
eval_names = _EVAL_NAMES.value
7271
cfg = _CONFIG.value
7372
trainer: kd.train.Trainer = kd.konfig.resolve(cfg)

kauldron/metrics/auto_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
import jax
2727
import jax.numpy as jnp
2828
from kauldron import kontext
29+
from kauldron.ktyping import Array, PyTree # pylint: disable=g-multiple-import,g-importing-member
2930
from kauldron.metrics import base_state
3031
from kauldron.metrics.base_state import EMPTY # pylint: disable=g-importing-member
31-
from kauldron.typing import Array, PyTree # pylint: disable=g-multiple-import,g-importing-member
3232
import numpy as np
3333

3434
_MetricT = TypeVar("_MetricT")

kauldron/metrics/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
import flax
2525
import jax
2626
from kauldron import kontext
27+
from kauldron.ktyping import Float, PyTree # pylint: disable=g-multiple-import,g-importing-member
2728
from kauldron.metrics import base_state
28-
from kauldron.typing import Float, PyTree # pylint: disable=g-multiple-import,g-importing-member
2929

3030
_FnT = TypeVar("_FnT")
3131

0 commit comments

Comments
 (0)