Thanks for the report! Leaked tracers generally come from caching traced values within a JIT-compiled function. The simplest version might look something like this:

```python
import jax

class Foo:
    def func(self, x):
        # Under jit, x is a tracer; storing it on self lets it escape the trace
        self.x = x
        return x

f = Foo()
jax.jit(f.func)(1)  # afterwards, f.x holds a leaked tracer
```

Your code doesn't do any of this kind of caching explicitly, but it happens implicitly via your use of I think the fix would be to avoid using the Best of luck!
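A minimal sketch of the usual way around this, assuming you simply want to keep the value around: make the jitted function pure (it only computes and returns), and do any caching on its concrete outputs after the call returns. The `step` function and the `state` dictionary are hypothetical illustrations, not JAX APIs.

```python
import jax
import jax.numpy as jnp

class Foo:
    # Hypothetical fixed version: func only computes and returns,
    # so no traced value is written to self during tracing.
    def func(self, x):
        return x

f = Foo()
out = jax.jit(f.func)(1)  # out is a concrete array once jit returns
f.x = out                 # cache the result outside the traced function

# The same idea for state that must evolve: pass it in, return it out.
@jax.jit
def step(state, x):
    new_state = {"last_x": x, "count": state["count"] + 1}
    return new_state, x * 2

state = {"last_x": jnp.array(0), "count": jnp.array(0)}
state, y = step(state, jnp.array(3))
```

Threading state through arguments and return values keeps every tracer inside the trace, which is what JAX's transformations (jit, vmap, grad) expect.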
I am trying to implement DP-SGD in JAX, and one of the steps uses vmap to clip the per-example gradients in a batch. However, I am getting a "trace leak" error. The error message is as follows:
Error message
Code is below:
I have traced the error to the line "clipped_grads, vs = vmap(clipped_grad, (None, 0))(l2_norm_clip, batch)". However, I don't know how to modify my code to fix it. Any help would be much appreciated.
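For reference, here is a minimal sketch of per-example clipping for DP-SGD written as pure functions, so nothing traced is stored on an object or global. It assumes a hypothetical linear model with a squared-error loss; the names `loss`, `clipped_grad`, `private_grad` and `noise_multiplier` are illustrative and not taken from the original (collapsed) code. Unlike the two-argument call in the question, `params` is passed explicitly, so `in_axes` gets an extra `None`.

```python
import jax
import jax.numpy as jnp

def loss(params, example):
    # Hypothetical squared-error loss for a linear model on one (x, y) example.
    x, y = example
    pred = jnp.dot(x, params["w"]) + params["b"]
    return (pred - y) ** 2

def clipped_grad(params, l2_norm_clip, example):
    # Gradient for a single example.
    grads = jax.grad(loss)(params, example)
    # Global L2 norm across all leaves of the gradient pytree.
    leaves = jax.tree_util.tree_leaves(grads)
    total_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Scale down so the norm is at most l2_norm_clip (no-op if already smaller).
    scale = jnp.minimum(1.0, l2_norm_clip / (total_norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

@jax.jit
def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier):
    # params and l2_norm_clip are shared (None); batch is mapped over axis 0.
    per_example = jax.vmap(clipped_grad, in_axes=(None, None, 0))(
        params, l2_norm_clip, batch)
    # Sum the clipped gradients over the batch.
    summed = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), per_example)
    # Add Gaussian noise with std noise_multiplier * l2_norm_clip to each leaf.
    leaves, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(rng, len(leaves))
    noisy = [g + noise_multiplier * l2_norm_clip * jax.random.normal(k, g.shape)
             for g, k in zip(leaves, keys)]
    # Average over the batch size (a static shape, so fine under jit).
    batch_size = batch[0].shape[0]
    return jax.tree_util.tree_unflatten(treedef, [g / batch_size for g in noisy])
```

A call would look like `private_grad(params, (x, y), jax.random.PRNGKey(0), 1.0, 1.1)`, with `x` of shape `(batch, d)` and `y` of shape `(batch,)`. Because `clipped_grad` only reads its arguments and returns a new pytree, vmap and jit can trace it without any tracer escaping.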