Explicitly trigger jit recompilation #10106
-
After jitting a Python function, is there any way to manually re-jit the inner, "unjitted" function? One use case is decorating an instance method, e.g.:

```python
from functools import partial
from typing import NamedTuple

import jax


class Properties(NamedTuple):
    a: int
    b: int


class Foo:
    def __init__(self, properties: Properties):
        self.properties = properties

    @partial(jax.jit, static_argnums=0)
    def update(self, x):
        return self.properties.a / self.properties.b * x


class Bar(Foo):
    def __init__(self, properties):
        super().__init__(properties)
        self.properties = self.properties._replace(a=3)


x = jax.numpy.array((2, 3))
properties = Properties(1, 2)

foo = Foo(properties)
y = foo.update(x)
assert jax.numpy.array_equal(y, 1 / 2 * x)
print(y)

bar = Bar(properties)
y = bar.update(x)
assert jax.numpy.array_equal(y, 3 / 2 * x)  # fails
print(y)
```

There might be design constraints for which … The question is whether there is something like: …
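The caching behavior behind this question can be sketched as follows. With `static_argnums=0`, `self` becomes part of the jit cache key through its `__hash__` and `__eq__`; a plain class uses identity for both, so mutating an attribute of the *same* instance after the first call keeps hitting the old compilation. A minimal sketch, assuming you are free to add `__hash__`/`__eq__` to the class (this follows the approach described in the JAX FAQ on using `jit` with methods), makes the key depend on `self.properties`, so a changed value forces a retrace. The class names `StaleFoo` and `HashedFoo` are hypothetical.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Properties(NamedTuple):
    a: int
    b: int


class StaleFoo:
    """Default identity-based hash/eq: mutating self does not invalidate the cache."""

    def __init__(self, properties: Properties):
        self.properties = properties

    @partial(jax.jit, static_argnums=0)
    def update(self, x):
        # a and b are read at trace time and baked into the compiled code.
        return self.properties.a / self.properties.b * x


class HashedFoo(StaleFoo):
    """Hash/eq derived from the fields that the jitted method reads."""

    def __hash__(self):
        return hash((type(self), self.properties))

    def __eq__(self, other):
        return type(other) is type(self) and other.properties == self.properties


x = jnp.array((2, 3))

stale = StaleFoo(Properties(1, 2))
y_first = stale.update(x)
stale.properties = stale.properties._replace(a=3)
y_stale = stale.update(x)  # same `self` object -> cache hit -> still 1/2 * x

fresh = HashedFoo(Properties(1, 2))
y0 = fresh.update(x)
fresh.properties = fresh.properties._replace(a=3)
y1 = fresh.update(x)  # hash changed -> cache miss -> retrace -> 3/2 * x
print(y_stale, y1)
```

The trade-off is that every distinct `properties` value gets its own compilation, so this suits configuration that changes rarely. It also only works while every field is hashable: `Properties` holds plain ints here, whereas a field holding a JAX array would make `hash(self.properties)` raise a `TypeError`.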
Replies: 2 comments
-
```python
assert jax.numpy.array_equal(y, 3 / 2 * x)  # fails
y = bar.update(x)
print(y)
```

should be

```python
y = bar.update(x)
assert jax.numpy.array_equal(y, 3 / 2 * x)  # fails
print(y)
```

And it seems that you don't use `_replace` correctly, since `NamedTuple` is immutable.

```python
from functools import partial
from typing import NamedTuple

import jax


class Properties(NamedTuple):
    a: int
    b: int


class Foo:
    def __init__(self, properties: Properties):
        self.properties = properties

    @partial(jax.jit, static_argnums=0)
    def update(self, x):
        return self.properties.a / self.properties.b * x


class Bar(Foo):
    def __init__(self, properties):
        super().__init__(properties)
        # _replace returns a new tuple, so the result must be assigned back.
        self.properties = self.properties._replace(a=3)


x = jax.numpy.array((2, 3))
properties = Properties(1, 2)

foo = Foo(properties)
y = foo.update(x)
assert jax.numpy.array_equal(y, 1 / 2 * x)

bar = Bar(properties)
y = bar.update(x)
assert jax.numpy.array_equal(y, 3 / 2 * x)  # pass
```
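To make the immutability point concrete, here is a minimal pure-Python check: `_replace` leaves the original tuple untouched and returns a new one, which is why its result has to be assigned back, as in `self.properties = self.properties._replace(a=3)`. Assigning to a field directly raises `AttributeError` (the exact message varies between Python versions, so it is not shown).

```python
from typing import NamedTuple


class Properties(NamedTuple):
    a: int
    b: int


p = Properties(1, 2)

# _replace never mutates; it builds a new Properties instance.
q = p._replace(a=3)
print(p, q)  # Properties(a=1, b=2) Properties(a=3, b=2)

# Direct field assignment is rejected because NamedTuple instances are immutable.
raised = False
try:
    p.a = 3
except AttributeError:
    raised = True
print(raised)  # True
```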
-
In addition to the existing (correct) answer: if you want to use classes around …
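One commonly used alternative when wrapping jitted methods in classes is to register the class as a pytree with `jax.tree_util.register_pytree_node_class`, so that `self` is an ordinary traced argument instead of a static one. It is offered here as an illustration, not as a reconstruction of what this reply goes on to say. Changing `self.properties` then changes the *inputs* rather than the cache key, so no recompilation is needed as long as shapes and dtypes stay the same. This is a minimal sketch assuming the fields read inside the method are numeric; the class name `TreeFoo` is hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class Properties(NamedTuple):
    a: int
    b: int


@register_pytree_node_class
class TreeFoo:
    def __init__(self, properties: Properties):
        self.properties = properties

    # The NamedTuple is itself a pytree, so `a` and `b` become traced
    # leaves under jit instead of constants baked into the compilation.
    def tree_flatten(self):
        return (self.properties,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def update(self, x):
        return self.properties.a / self.properties.b * x


x = jnp.array((2, 3))
foo = TreeFoo(Properties(1, 2))
y0 = foo.update(x)
foo.properties = foo.properties._replace(a=3)
y1 = foo.update(x)  # new leaf values, same shapes/dtypes: correct result, no retrace needed
print(y0, y1)
```

As far as I know, the pytree registry is keyed on the exact class, so a subclass such as `Bar` would need its own `@register_pytree_node_class` decoration.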