There seems to be an issue in the `inference_rnn` function, where `inference_cell` and `generator_cell` are connected:
Lines 417 to 421 in bc84f24:

```python
# estimate statistics and sample state from posterior
mu_q, sigma_q, z_q = compute_eta_and_sample_z(inf_state.lstm.h,
                                              scope="Sample_eta_q")
# input into generator RNN
gen_input = _GeneratorCellInput(representations, query_poses, z_q)
```
It looks like the gradient flows back through `z_q` into the inference cell. Adding the line `z_q = tf.stop_gradient(z_q)` seems to improve the results when only `generator_rnn` is used at test time.
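A minimal sketch of what `tf.stop_gradient` changes here, assuming TensorFlow 2 eager execution with `tf.GradientTape` (the repository code is graph-mode TF1, but the op has the same effect on backprop there). The scalars `w_inf`, `w_gen`, `h`, `eps`, the helper `gradients`, and the squared-error loss are hypothetical toy stand-ins for the inference cell weights, the generator cell weights, `inf_state.lstm.h`, the sampling noise, and the generator's loss; only `mu_q`, `sigma_q` and `z_q` mirror names from the snippet above.

```python
import tensorflow as tf

# Hypothetical toy stand-ins (not from the repository):
# w_inf plays the role of the inference cell's weights,
# w_gen plays the role of the generator cell's weights.
w_inf = tf.Variable(0.5)
w_gen = tf.Variable(2.0)
h = tf.constant(1.0)    # stand-in for inf_state.lstm.h
eps = tf.constant(0.3)  # fixed noise so the example is deterministic


def gradients(stop_grad):
    """Return (dloss/dw_inf, dloss/dw_gen) for a toy generator loss."""
    with tf.GradientTape() as tape:
        # Posterior statistics computed from the inference state.
        mu_q = w_inf * h
        sigma_q = tf.exp(w_inf)
        # Reparameterized sample: differentiable w.r.t. mu_q and sigma_q.
        z_q = mu_q + sigma_q * eps
        if stop_grad:
            # Treat z_q as a constant during backprop, so the generator
            # loss no longer sends gradients to the inference-side weight.
            z_q = tf.stop_gradient(z_q)
        # Toy "generator" loss that consumes z_q.
        loss = tf.square(w_gen * z_q - 1.0)
    # ZERO reports an unreachable variable as 0.0 instead of None.
    g_inf, g_gen = tape.gradient(
        loss, [w_inf, w_gen],
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    return float(g_inf), float(g_gen)


print("without stop_gradient:", gradients(stop_grad=False))
print("with stop_gradient:   ", gradients(stop_grad=True))
```

In this toy setup the generator-side gradient is the same in both runs; the only difference is that the path from the generator loss back into the inference-side weight through `z_q` is cut.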