@@ -28,20 +28,19 @@ def __init__(self):
         def forward(self, x, y):
             return self.loss(self.linear(x).softmax(dim=0), y)
 
-        def get_random_inputs(self):
-            return (torch.randn(3), torch.tensor([1.0, 0.0, 0.0]))
+        def get_inputs(self):
+            return (torch.ones(3, dtype=torch.float32), torch.tensor([1.0, 0.0, 0.0]))
 
     def test(self):
         m = self.ModuleSimpleTrain()
-        ep = torch.export.export(m, m.get_random_inputs(), strict=True)
+        ep = torch.export.export(m, m.get_inputs(), strict=True)
         ep = _export_forward_backward(ep)
         ep = to_edge(ep)
         ep = ep.to_executorch()
         buffer = ep.buffer
         tm = _load_for_executorch_for_training_from_buffer(buffer)
 
-        tm.forward_backward("forward", m.get_random_inputs())
-        orig_param = list(tm.named_parameters().values())[0].clone()
+        orig_loss = tm.forward_backward("forward", m.get_inputs())
         optimizer = get_sgd_optimizer(
             tm.named_parameters(),
             0.1,
@@ -50,7 +49,19 @@ def test(self):
             0,
             False,
         )
+
+        cloned_params = list(tm.named_parameters().values())
+        cloned_params = [p.clone() for p in cloned_params]
+
         optimizer.step(tm.named_gradients())
-        self.assertFalse(
-            torch.allclose(orig_param, list(tm.named_parameters().values())[0])
-        )
+
+        # The Python module caches the param tensors after the first
+        # inference, so this does not test whether the params are actually
+        # updated on the C++ side.
+        for p, cloned_p in zip(tm.named_parameters().values(), cloned_params):
+            self.assertFalse(torch.allclose(p, cloned_p))
+
+        # Test that the params actually changed in C++ by running against
+        # the same inputs again and checking that the loss is different.
+        second_loss = tm.forward_backward("forward", m.get_inputs())
+        self.assertFalse(torch.allclose(orig_loss[0], second_loss[0]))
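
For context, the calls this test exercises compose into a small end-to-end training loop. The sketch below is illustrative only and not part of the commit: the import paths, the module's __init__ body (Linear(3, 3) + CrossEntropyLoss), the middle get_sgd_optimizer arguments, and the loop count are assumptions inferred from the identifiers visible in the diff.

    # Sketch of the flow the test exercises; import paths are assumed from the
    # identifiers used in the diff and may need adjusting.
    import torch
    from executorch.exir import to_edge
    from executorch.extension.training.pybindings._training_lib import get_sgd_optimizer
    from executorch.extension.training.pybindings._training_module import (
        _load_for_executorch_for_training_from_buffer,
    )
    from torch.export.experimental import _export_forward_backward


    class ModuleSimpleTrain(torch.nn.Module):
        # Stand-in for the test's inner module; its __init__ is not shown in
        # the diff, so Linear(3, 3) + CrossEntropyLoss is an assumption.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, y):
            return self.loss(self.linear(x).softmax(dim=0), y)

        def get_inputs(self):
            return (torch.ones(3, dtype=torch.float32), torch.tensor([1.0, 0.0, 0.0]))


    m = ModuleSimpleTrain()

    # Export the joint forward/backward graph and lower it to an ExecuTorch buffer.
    ep = torch.export.export(m, m.get_inputs(), strict=True)
    ep = _export_forward_backward(ep)
    ep = to_edge(ep).to_executorch()

    # Load the buffer into the training runtime and build an SGD optimizer over
    # its parameters (lr=0.1; the remaining arguments mirror the test's call and
    # are assumed to be momentum, dampening, weight_decay, nesterov).
    tm = _load_for_executorch_for_training_from_buffer(ep.buffer)
    optimizer = get_sgd_optimizer(tm.named_parameters(), 0.1, 0.0, 0.0, 0.0, False)

    # With fixed inputs, the loss reported by forward_backward should change as
    # the C++-side parameters are updated each step.
    for step in range(5):
        outputs = tm.forward_backward("forward", m.get_inputs())
        optimizer.step(tm.named_gradients())
        print(f"step {step}: loss = {outputs[0].item():.4f}")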