1313 LearningRateExp ,
1414)
1515from deepmd .tf .env import (
16+ GLOBAL_TF_FLOAT_PRECISION ,
1617 tf ,
1718)
1819from deepmd .tf .utils .learning_rate import (
@@ -42,13 +43,13 @@ class TestLearningRateScheduleBuild(unittest.TestCase):
4243 """Test TF tensor building and integration."""
4344
4445 def test_build_returns_tensor (self ) -> None :
45- """Test that build() returns a float64 TF tensor (consistent with GLOBAL_TF_FLOAT_PRECISION) ."""
46+ """Test that build() returns a TF tensor with correct dtype ."""
4647 lr_schedule = LearningRateSchedule ({"start_lr" : 1e-3 , "stop_lr" : 1e-5 })
4748 global_step = tf .constant (0 , dtype = tf .int64 )
4849 lr_tensor = lr_schedule .build (global_step , num_steps = 10000 )
4950
5051 self .assertIsInstance (lr_tensor , tf .Tensor )
51- self .assertEqual (lr_tensor .dtype , tf . float64 )
52+ self .assertEqual (lr_tensor .dtype , GLOBAL_TF_FLOAT_PRECISION )
5253
5354 def test_default_type_exp (self ) -> None :
5455 """Test that default type is 'exp' when not specified."""
@@ -58,8 +59,8 @@ def test_default_type_exp(self) -> None:
5859
5960 self .assertIsInstance (lr_schedule .base_lr , LearningRateExp )
6061
61- def test_tensor_value_matches_base_lr (self ) -> None :
62- """Test that TF tensor value matches BaseLR.value()."""
62+ def test_value_method_matches_base_lr (self ) -> None :
63+ """Test that value() method matches BaseLR.value() after build ."""
6364 lr_schedule = LearningRateSchedule (
6465 {
6566 "start_lr" : 1e-3 ,
@@ -72,12 +73,11 @@ def test_tensor_value_matches_base_lr(self) -> None:
7273 global_step = tf .constant (test_step , dtype = tf .int64 )
7374 lr_schedule .build (global_step , num_steps = 10000 )
7475
75- # Use value() method which works in both graph and eager mode
76- # This indirectly verifies tensor computation matches BaseLR
77- tensor_value = lr_schedule .value (test_step )
76+ # value() method returns base_lr.value() as float
77+ method_value = lr_schedule .value (test_step )
7878 base_lr_value = lr_schedule .base_lr .value (test_step )
7979
80- np .testing .assert_allclose (tensor_value , base_lr_value , rtol = 1e-10 )
80+ np .testing .assert_allclose (method_value , base_lr_value , rtol = 1e-10 )
8181
8282 def test_start_lr_accessor (self ) -> None :
8383 """Test start_lr() accessor returns correct value."""
0 commit comments