Skip to content

Commit ad40b00

Browse files
committed
fix test bugs
1 parent 947e8a6 commit ad40b00

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

source/tests/pd/model/test_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def _get_dp_placeholders(self, dataset):
244244
prec = tf.float64
245245
place_holders[kk] = tf.placeholder(prec, [None], name="t_" + kk)
246246
place_holders["find_" + kk] = tf.placeholder(
247-
tf.float32, name="t_find_" + kk
247+
tf.float64, name="t_find_" + kk
248248
)
249249
place_holders["type"] = tf.placeholder(tf.int32, [None], name="t_type")
250250
place_holders["natoms_vec"] = tf.placeholder(
@@ -303,7 +303,12 @@ def test_consistency(self) -> None:
303303
},
304304
)
305305
my_model.to(DEVICE)
306-
my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.num_steps)
306+
my_lr = MyLRExp(
307+
self.start_lr,
308+
self.stop_lr,
309+
decay_steps=self.decay_steps,
310+
num_steps=self.num_steps,
311+
)
307312
my_loss = EnergyStdLoss(
308313
starter_learning_rate=self.start_lr,
309314
start_pref_e=self.start_pref_e,

source/tests/pt/model/test_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def _get_dp_placeholders(self, dataset):
244244
prec = tf.float64
245245
place_holders[kk] = tf.placeholder(prec, [None], name="t_" + kk)
246246
place_holders["find_" + kk] = tf.placeholder(
247-
tf.float32, name="t_find_" + kk
247+
tf.float64, name="t_find_" + kk
248248
)
249249
place_holders["type"] = tf.placeholder(tf.int32, [None], name="t_type")
250250
place_holders["natoms_vec"] = tf.placeholder(
@@ -303,7 +303,12 @@ def test_consistency(self) -> None:
303303
},
304304
)
305305
my_model.to(DEVICE)
306-
my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.num_steps)
306+
my_lr = MyLRExp(
307+
self.start_lr,
308+
self.stop_lr,
309+
decay_steps=self.decay_steps,
310+
num_steps=self.num_steps,
311+
)
307312
my_loss = EnergyStdLoss(
308313
starter_learning_rate=self.start_lr,
309314
start_pref_e=self.start_pref_e,

source/tests/tf/test_lr.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
LearningRateExp,
1414
)
1515
from deepmd.tf.env import (
16+
GLOBAL_TF_FLOAT_PRECISION,
1617
tf,
1718
)
1819
from deepmd.tf.utils.learning_rate import (
@@ -42,13 +43,13 @@ class TestLearningRateScheduleBuild(unittest.TestCase):
4243
"""Test TF tensor building and integration."""
4344

4445
def test_build_returns_tensor(self) -> None:
45-
"""Test that build() returns a float64 TF tensor (consistent with GLOBAL_TF_FLOAT_PRECISION)."""
46+
"""Test that build() returns a TF tensor with correct dtype."""
4647
lr_schedule = LearningRateSchedule({"start_lr": 1e-3, "stop_lr": 1e-5})
4748
global_step = tf.constant(0, dtype=tf.int64)
4849
lr_tensor = lr_schedule.build(global_step, num_steps=10000)
4950

5051
self.assertIsInstance(lr_tensor, tf.Tensor)
51-
self.assertEqual(lr_tensor.dtype, tf.float64)
52+
self.assertEqual(lr_tensor.dtype, GLOBAL_TF_FLOAT_PRECISION)
5253

5354
def test_default_type_exp(self) -> None:
5455
"""Test that default type is 'exp' when not specified."""
@@ -58,8 +59,8 @@ def test_default_type_exp(self) -> None:
5859

5960
self.assertIsInstance(lr_schedule.base_lr, LearningRateExp)
6061

61-
def test_tensor_value_matches_base_lr(self) -> None:
62-
"""Test that TF tensor value matches BaseLR.value()."""
62+
def test_value_method_matches_base_lr(self) -> None:
63+
"""Test that value() method matches BaseLR.value() after build."""
6364
lr_schedule = LearningRateSchedule(
6465
{
6566
"start_lr": 1e-3,
@@ -72,12 +73,11 @@ def test_tensor_value_matches_base_lr(self) -> None:
7273
global_step = tf.constant(test_step, dtype=tf.int64)
7374
lr_schedule.build(global_step, num_steps=10000)
7475

75-
# Use value() method which works in both graph and eager mode
76-
# This indirectly verifies tensor computation matches BaseLR
77-
tensor_value = lr_schedule.value(test_step)
76+
# value() method returns base_lr.value() as float
77+
method_value = lr_schedule.value(test_step)
7878
base_lr_value = lr_schedule.base_lr.value(test_step)
7979

80-
np.testing.assert_allclose(tensor_value, base_lr_value, rtol=1e-10)
80+
np.testing.assert_allclose(method_value, base_lr_value, rtol=1e-10)
8181

8282
def test_start_lr_accessor(self) -> None:
8383
"""Test start_lr() accessor returns correct value."""

0 commit comments

Comments (0)