@@ -41,27 +41,41 @@ fn main() -> Result<(), Box<dyn Error>> {
4141
4242 // Load the saved model exported by regression_savedmodel.py.
4343 let mut graph = Graph :: new ( ) ;
44- let session =
45- SavedModelBundle :: load ( & SessionOptions :: new ( ) , & [ "serve" ] , & mut graph, export_dir) ?. session ;
46- let op_x = graph. operation_by_name_required ( "train_x" ) ?;
47- let op_y = graph. operation_by_name_required ( "train_y" ) ?;
48- let op_train = graph. operation_by_name_required ( "StatefulPartitionedCall" ) ?;
49- let op_w = graph. operation_by_name_required ( "StatefulPartitionedCall_1" ) ?;
50- let op_b = graph. operation_by_name_required ( "StatefulPartitionedCall_1" ) ?;
44+ let bundle =
45+ SavedModelBundle :: load ( & SessionOptions :: new ( ) , & [ "serve" ] , & mut graph, export_dir) ?;
46+ let session = & bundle. session ;
47+
48+ let train_signature = bundle. meta_graph_def ( ) . get_signature ( "train" ) ?;
49+ let x_info = train_signature. get_input ( "x" ) ?;
50+ let y_info = train_signature. get_input ( "y" ) ?;
51+ let train_info = train_signature. get_output ( "train" ) ?;
52+ let op_x = graph. operation_by_name_required ( & x_info. name ( ) . name ) ?;
53+ let op_y = graph. operation_by_name_required ( & y_info. name ( ) . name ) ?;
54+ let op_train = graph. operation_by_name_required ( & train_info. name ( ) . name ) ?;
55+ let w_info = bundle
56+ . meta_graph_def ( )
57+ . get_signature ( "w" ) ?
58+ . get_output ( "output" ) ?;
59+ let op_w = graph. operation_by_name_required ( & w_info. name ( ) . name ) ?;
60+ let b_info = bundle
61+ . meta_graph_def ( )
62+ . get_signature ( "b" ) ?
63+ . get_output ( "output" ) ?;
64+ let op_b = graph. operation_by_name_required ( & b_info. name ( ) . name ) ?;
5165
5266 // Train the model (e.g. for fine tuning).
5367 let mut train_step = SessionRunArgs :: new ( ) ;
5468 train_step. add_feed ( & op_x, 0 , & x) ;
5569 train_step. add_feed ( & op_y, 0 , & y) ;
56- train_step. request_fetch ( & op_train, 0 ) ;
70+ train_step. add_target ( & op_train) ;
5771 for _ in 0 ..steps {
5872 session. run ( & mut train_step) ?;
5973 }
6074
6175 // Grab the data out of the session.
6276 let mut output_step = SessionRunArgs :: new ( ) ;
6377 let w_ix = output_step. request_fetch ( & op_w, 0 ) ;
64- let b_ix = output_step. request_fetch ( & op_b, 1 ) ;
78+ let b_ix = output_step. request_fetch ( & op_b, 0 ) ;
6579 session. run ( & mut output_step) ?;
6680
6781 // Check our results.
0 commit comments