Skip to content

Commit 551721a

Browse files
authored
address Copilot's suggestion
Refactor flax_module to use a generic type for better type safety. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 51985af commit 551721a

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

deepmd/jax/common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None:
4444
return jnp.array(array)
4545

4646

47+
T = TypeVar('T')
48+
49+
4750
def flax_module(
48-
module: type[NativeOP],
49-
) -> type[nnx.Module]:
51+
module: type[T],
52+
) -> type[T]: # runtime: actually returns type[T & nnx.Module]
5053
"""Convert a NativeOP to a Flax module.
5154
5255
Parameters

0 commit comments

Comments
 (0)