Skip to content

Commit 2409d3d

Browse files
committed
fix
1 parent 9bfcf24 commit 2409d3d

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

deepmd/pt/utils/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,6 @@ def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...
227227
def to_numpy_array(xx: None) -> None: ...
228228

229229

230-
@overload
231-
def to_numpy_array(xx: float) -> np.ndarray: ...
232-
233-
234230
def to_numpy_array(
235231
xx: torch.Tensor | np.ndarray | float | None,
236232
) -> np.ndarray | None:

deepmd/tf/utils/learning_rate.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
BaseLR,
1414
)
1515
from deepmd.tf.env import (
16+
GLOBAL_TF_FLOAT_PRECISION,
1617
tf,
1718
)
1819

@@ -90,13 +91,11 @@ def build(self, global_step: tf.Tensor, num_steps: int) -> tf.Tensor:
9091
self._base_lr = BaseLR(**params)
9192

9293
# === Step 2. Bind a numpy_function for runtime evaluation ===
93-
from deepmd.tf.env import (
94-
GLOBAL_TF_FLOAT_PRECISION,
95-
)
94+
base_lr = self._base_lr
9695

9796
def _lr_value(step: np.ndarray) -> np.ndarray:
9897
return np.asarray(
99-
self._base_lr.value(step),
98+
base_lr.value(step),
10099
dtype=GLOBAL_TF_FLOAT_PRECISION.as_numpy_dtype,
101100
)
102101

0 commit comments

Comments
 (0)