Weird error when JITting a method #19434
Replies: 2 comments
-
We have a section in the FAQ about using `jit` with methods. Please let me know if that doesn't answer your question!
-
Looking closer, I think this is a problem with pytree registration. In your code, `eps` should be returned among the children (dynamic values) rather than as static `aux_data`:

```python
def tree_flatten(self):
    children = (self.eps,)  # arrays / dynamic values
    aux_data = {}  # static values
    return children, aux_data
```

Alternatively, you could store it as an `int`:

```python
def __init__(self, eps: int):
    self.eps = int(eps)
```
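Here is a minimal, self-contained sketch of both fixes, assuming a class that only holds `eps` and has a jitted `loss` method. The class names `DynamicEps`/`StaticEps` and the `loss` body are hypothetical, since the original code isn't reproduced here. In the first option `eps` is a child, so `jit` traces it like any other array argument; in the second it is a plain Python `int` in `aux_data`, which JAX's pytree docs ask to be hashable and comparable with `==`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class DynamicEps:  # hypothetical: option 1, eps as a dynamic child
    def __init__(self, eps):
        self.eps = eps

    def tree_flatten(self):
        children = (self.eps,)  # arrays / dynamic values, traced under jit
        aux_data = None         # nothing static to carry
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def loss(self, x):
        # self is a pytree, so jit flattens it and traces self.eps
        return jnp.sum((x - self.eps) ** 2)


@tree_util.register_pytree_node_class
class StaticEps:  # hypothetical: option 2, eps as a static Python int
    def __init__(self, eps: int):
        self.eps = int(eps)  # hashable and comparable with ==

    def tree_flatten(self):
        children = ()           # no dynamic values
        aux_data = (self.eps,)  # static: a new eps value triggers a recompile
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data)

    @jax.jit
    def loss(self, x):
        return jnp.sum((x - self.eps) ** 2)


x = jnp.array([1.0, 2.0])
print(float(DynamicEps(jnp.array(0.5)).loss(x)))  # 2.5
print(float(StaticEps(1).loss(x)))                 # 1.0
```

The trade-off: keeping `eps` as a child means changing its value reuses the compiled function, while the static-`int` version recompiles for every distinct `eps`.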
-
Consider the following code:
It looks pretty innocuous, but when I run it, the last line raises an error. The error message is the following:
I noticed that if I remove the `jit` from the `loss` method, the error disappears. Why is that? Am I using it wrongly? I am using jax 0.4.23 and Python 3.10 on a SageMaker `ml.p3` instance.
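One way to see why removing `jit` hides the error: a plain method call never touches the pytree registration, whereas calling a jitted method makes JAX flatten `self` through `tree_flatten`. So a mistake in the flatten/unflatten methods only shows up once a transformation like `jit` is applied. A small sketch, where the `Probe` class, its methods, and the `flatten_calls` list are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

flatten_calls = []  # records every time JAX flattens a Probe instance


@tree_util.register_pytree_node_class
class Probe:  # hypothetical class, for illustration only
    def __init__(self, eps):
        self.eps = eps

    def tree_flatten(self):
        flatten_calls.append("flatten")
        return (self.eps,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def loss_plain(self, x):
        # ordinary Python method: self is never flattened
        return jnp.sum((x - self.eps) ** 2)

    # same function, but jit flattens its arguments (including self)
    loss_jit = jax.jit(loss_plain)


p = Probe(jnp.array(0.5))
x = jnp.array([1.0, 2.0])

p.loss_plain(x)
after_plain = len(flatten_calls)
p.loss_jit(x)
after_jit = len(flatten_calls)
print(after_plain, after_jit > 0)  # 0 True
```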