Skip to content

Commit c1c0b2c

Browse files
Copilotnjzjz
andcommitted
feat(jax): enable ANN rule and add comprehensive type hints to JAX backend
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 4064b3b commit c1c0b2c

File tree

3 files changed

+509
-590
lines changed

3 files changed

+509
-590
lines changed

deepmd/jax/utils/serialization.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
from pathlib import (
33
Path,
44
)
5-
from typing import (
6-
Any,
7-
)
85

96
import numpy as np
107
import orbax.checkpoint as ocp
@@ -58,7 +55,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
5855

5956
def exported_whether_do_atomic_virial(
6057
do_atomic_virial: bool, has_ghost_atoms: bool
61-
) -> Any:
58+
) -> jax_export.Exported:
6259
def call_lower_with_fixed_do_atomic_virial(
6360
coord: jnp.ndarray,
6461
atype: jnp.ndarray,

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,6 @@ runtime-evaluated-base-classes = ["torch.nn.Module"]
425425
"deepmd/tf/**" = ["TID253", "ANN"]
426426
"deepmd/pt/**" = ["TID253"]
427427
"deepmd/jax/**" = ["TID253"]
428-
"deepmd/jax/jax2tf/**" = ["TID253", "ANN"]
429428
"deepmd/pd/**" = ["TID253", "ANN"]
430429

431430
"source/**" = ["ANN"]

0 commit comments

Comments
 (0)