-
When training a simple DC-GAN in TensorFlow, you can do one forward pass with the generator and two with the discriminator per train step by using multiple gradient tapes. Example from the TF DC-GAN tutorial:

```python
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    generated_images = generator(noise, training=True)

    real_output = discriminator(images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
```

In JAX, you need two different functions for the two separate sets of gradients. Without JIT, this results in 2 generator calls and 3 discriminator calls per step. Below is some skeleton code (using Flax):

```python
@jax.jit
def train_step(g_params, d_params, real_imgs, rng):
    batch_size = real_imgs.shape[0]
    noise = jax.random.normal(rng, (batch_size, 100))
    real_labels = jnp.ones((batch_size, 1))
    fake_labels = jnp.zeros((batch_size, 1))

    def g_loss_fn(gp):
        # SUBEXPR START
        fake_imgs = generator.apply(gp, noise)
        fake_preds = discriminator.apply(d_params, fake_imgs)
        # SUBEXPR END
        return bce_from_logits(fake_preds, real_labels).mean()

    def d_loss_fn(dp):
        # SUBEXPR START
        fake_imgs = generator.apply(g_params, noise)
        fake_preds = discriminator.apply(dp, fake_imgs)
        # SUBEXPR END
        real_preds = discriminator.apply(dp, real_imgs)
        fake_loss = bce_from_logits(fake_preds, fake_labels).mean()
        real_loss = bce_from_logits(real_preds, real_labels).mean()
        return (fake_loss + real_loss) / 2.0

    g_grads = jax.grad(g_loss_fn)(g_params)
    d_grads = jax.grad(d_loss_fn)(d_params)

    # plain SGD update on both parameter pytrees
    # (note: jax.tree_multimap has since been renamed; newer JAX uses jax.tree_util.tree_map)
    return jax.tree_multimap(
        lambda g, p: p - 0.1 * g,
        (g_grads, d_grads),
        (g_params, d_params),
    )
```

Full notebook here. My question is: when JIT'd, does XLA combine the common subexpressions (the SUBEXPR blocks marked above) between `g_loss_fn` and `d_loss_fn`?
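As a self-contained toy version of the same pattern (illustrative names only, not from the notebook), the duplication is visible in the jaxpr before XLA ever sees it:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the pattern above: two losses whose gradients are taken
# separately, but which share the subexpression jnp.tanh(W @ x).
def toy_step(W, x):
    def loss_a(W):
        shared = jnp.tanh(W @ x)   # common subexpression
        return jnp.sum(shared ** 2)

    def loss_b(W):
        shared = jnp.tanh(W @ x)   # same subexpression again
        return jnp.sum(jnp.abs(shared))

    return jax.grad(loss_a)(W), jax.grad(loss_b)(W)

W = jnp.ones((4, 4))
x = jnp.ones((4,))
# The jaxpr (what JAX hands to XLA) still contains both copies of the shared
# computation; whether they get merged is up to XLA's optimizations under jit.
print(jax.make_jaxpr(toy_step)(W, x))
```

In the printed jaxpr the dot/tanh pair appears once per loss function; whether the compiled executable actually computes it twice is exactly the question here.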
-
Based on timings, it seems like the answer is yes, JAX will combine those common subexpressions (timings here). The original train_step is the one above. First, I modified it so the generator forward pass runs only once, by returning fake_imgs from g_loss_fn as an auxiliary output:

```python
def g_loss_fn(gp):
    fake_imgs = generator.apply(gp, noise)
    fake_preds = discriminator.apply(d_params, fake_imgs)
    return bce_from_logits(fake_preds, real_labels).mean(), fake_imgs

# added this line
(_, fake_imgs), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)(g_params)

def d_loss_fn(dp):
    # removed generator forward pass HERE
    fake_preds = discriminator.apply(dp, fake_imgs)
    real_preds = discriminator.apply(dp, real_imgs)
    fake_loss = bce_from_logits(fake_preds, fake_labels).mean()
    real_loss = bce_from_logits(real_preds, real_labels).mean()
    return (fake_loss + real_loss) / 2.0
d_grads = jax.grad(d_loss_fn)(d_params)
```

Then I made the following modification to remove the additional discriminator forward pass on the fake images:

```python
def g_loss_fn(gp, dp):
    fake_imgs = generator.apply(gp, noise)
    fake_preds = discriminator.apply(dp, fake_imgs)
    return bce_from_logits(fake_preds, real_labels).mean(), fake_preds

def d_loss_fn(dp, gp):
    # reuse fake_preds from g_loss_fn instead of a second discriminator pass on fake_imgs
    (_, fake_preds), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)(gp, dp)
    real_preds = discriminator.apply(dp, real_imgs)
    fake_loss = bce_from_logits(fake_preds, fake_labels).mean()
    real_loss = bce_from_logits(real_preds, real_labels).mean()
    return (fake_loss + real_loss) / 2.0, (fake_preds, g_grads)
(_, (_, g_grads)), d_grads = jax.value_and_grad(d_loss_fn, has_aux=True)(d_params, g_params)
```

And the timings for all three step functions were about the same (tried a much bigger model with a large batch size, shown in the notebook).
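For reference, here is a minimal sketch of how such a comparison can be timed fairly in JAX. The names train_step_original, train_step_shared_imgs, and train_step_shared_preds are placeholders for the three jitted variants (they are not from the notebook), and g_params, d_params, real_imgs, rng are assumed to be set up as in the skeleton. The key points are warming up to exclude compilation time and blocking on the asynchronously dispatched results:

```python
import timeit
import jax

def time_step(step_fn, n=100):
    # warm-up call triggers compilation so it is excluded from the timing
    jax.block_until_ready(step_fn(g_params, d_params, real_imgs, rng))

    def run():
        # block so we measure compute time, not just async dispatch
        jax.block_until_ready(step_fn(g_params, d_params, real_imgs, rng))

    return timeit.timeit(run, number=n) / n

for name, fn in [("original", train_step_original),
                 ("shared fake_imgs", train_step_shared_imgs),
                 ("shared fake_preds", train_step_shared_preds)]:
    print(f"{name}: {time_step(fn) * 1e3:.2f} ms/step")
```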
-
XLA does CSE among other optimizations, and I suspect that it will do something like it here, but it gives no universal guarantees to this end. Compiler optimizations often consider trading off various resources (e.g. memory vs. computation). In principle they could even decide to carry out redundant computation, e.g. in order to re-materialize an intermediate value rather than holding it in memory while scheduling other operations.

As for how to find out post hoc whether an optimization seems to have taken place, you're right to consider profiling tools, XLA output, or even timing against expectations. But note again that, in principle, even for a fixed JAX-side program, the compiler's choices might change in the future.

We don't have a public API today for viewing post-optimization HLO as text. Using internal functions (no support or stability guarantee!), @hawkinsp has a workaround along the following lines:

```python
import jax
import jax.numpy as jnp
def optimized_hlo(f, *example_args):
    c = jax.xla_computation(f)(*example_args)
    e = jax.lib.xla_bridge.get_backend().compile(c)
    return e.hlo_modules()[0]

def f(x): return jnp.sin(jnp.cos(x))
print(optimized_hlo(f, 1.).to_string())
```

Once @zhangqiaorjc and I finish our work on introducing an ahead-of-time compile function (#6034), we could consider hanging a post-optimization HLO interface off of its output.
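As a hedged follow-up sketch (not from the original reply, and relying on the same unsupported helper above): feed optimized_hlo a function with a deliberately repeated subexpression and count the matching op in the optimized HLO text. The exact HLO text varies across XLA versions, so treat the string count only as a rough signal; in more recent JAX versions, the ahead-of-time API mentioned above (something like jax.jit(g).lower(x).compile().as_text()) exposes similar information.

```python
# Assumes the imports and the optimized_hlo helper defined above have been run.
# g has an obvious repeated subexpression (jnp.cos(x)), so if CSE kicked in we
# expect the "cosine(" op to appear only once in the optimized HLO.
def g(x):
    a = jnp.cos(x) * 2.0   # repeated subexpression
    b = jnp.cos(x) * 3.0   # same cos(x) computed again
    return a + b

hlo_text = optimized_hlo(g, 1.0).to_string()
print(hlo_text.count("cosine("))  # a count of 1 suggests the two cos calls were merged
```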