
Gradient flowing between inference_cell and generator_cell #42

@Rustleman


There seems to be an issue in the `inference_rnn` function, at the point where `inference_cell` and `generator_cell` are connected:

From tf-gqn/gqn/gqn_draw.py, lines 417 to 421 at commit bc84f24:

```python
# estimate statistics and sample state from posterior
mu_q, sigma_q, z_q = compute_eta_and_sample_z(inf_state.lstm.h,
                                              scope="Sample_eta_q")
# input into generator RNN
gen_input = _GeneratorCellInput(representations, query_poses, z_q)
```

It looks like gradients from the generator flow back into the inference cell through `z_q`.
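To make that path explicit, here is a sketch of the chain rule involved. It assumes (not confirmed by the snippet) that `compute_eta_and_sample_z` draws a reparameterized sample, and the symbols are introduced only for illustration: $h_e$ is `inf_state.lstm.h`, $\phi$ the inference-cell parameters, $\mathcal{L}_{\text{rec}}$ the generator's reconstruction loss.

$$
z_q = \mu_q(h_e;\phi) + \sigma_q(h_e;\phi)\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

$$
\frac{\partial \mathcal{L}_{\text{rec}}}{\partial \phi} = \frac{\partial \mathcal{L}_{\text{rec}}}{\partial z_q}\,\frac{\partial z_q}{\partial \phi}
$$

Wrapping the sample in `tf.stop_gradient` keeps its forward value but treats $\partial z_q / \partial \phi$ as zero in the backward pass, so $\phi$ no longer receives gradient from $\mathcal{L}_{\text{rec}}$ via $z_q$; it is still updated by any loss term that uses $\mu_q$ and $\sigma_q$ directly, such as a KL term between posterior and prior.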

Adding the line `z_q = tf.stop_gradient(z_q)` after the call to `compute_eta_and_sample_z` (i.e. before `gen_input` is built) seems to improve the results when only `generator_rnn` is used at test time.
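A minimal, self-contained sketch of what that line changes, not the repository's code. It uses TensorFlow 2 eager mode with `tf.GradientTape` for brevity (`tf.stop_gradient` has the same semantics in the repository's graph-mode code), and the scalar weights `w_inf`/`w_gen`, the fixed noise value, and the helpers `generator_loss`/`grads` are hypothetical stand-ins for the two cells. It checks that stopping the gradient on `z_q` leaves the generator's gradient unchanged while cutting the gradient into the inference weight.

```python
import tensorflow as tf

# Hypothetical one-weight stand-ins for the two networks.
w_inf = tf.Variable(0.5)  # plays the role of the inference_cell parameters
w_gen = tf.Variable(2.0)  # plays the role of the generator_cell parameters


def generator_loss(h_inf, target, stop_z_gradient):
    """Toy reconstruction loss of a generator fed with a posterior sample."""
    mu_q = w_inf * h_inf        # posterior mean computed from the inference state
    sigma_q = tf.constant(1.0)
    eps = tf.constant(0.3)      # fixed noise so the result is deterministic
    z_q = mu_q + sigma_q * eps  # reparameterized sample (an assumption)
    if stop_z_gradient:
        # Forward value is unchanged; the backward pass stops here.
        z_q = tf.stop_gradient(z_q)
    prediction = w_gen * z_q
    return tf.square(prediction - target)


def grads(stop_z_gradient):
    """Return (d loss / d w_inf, d loss / d w_gen) as Python floats."""
    with tf.GradientTape() as tape:
        loss = generator_loss(tf.constant(1.0), tf.constant(0.0), stop_z_gradient)
    g_inf, g_gen = tape.gradient(loss, [w_inf, w_gen])
    # The tape returns None for a source that no gradient reaches.
    g_inf = 0.0 if g_inf is None else float(g_inf)
    return g_inf, float(g_gen)


print(grads(stop_z_gradient=False))  # both weights receive a gradient
print(grads(stop_z_gradient=True))   # w_inf receives none
```

Without the stop, both weights get a gradient (about 6.4 for `w_inf` and 2.56 for `w_gen` in this toy setup); with it, `w_gen` still gets about 2.56 while `w_inf` gets none, which is the decoupling the proposed line introduces.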

Metadata

Assignees: No one assigned
Labels: No labels
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests