Skip to content
Discussion options

You must be logged in to vote
# Snippet as originally posted: the assertion reads ``y`` before the
# ``bar.update(x)`` call on the next line assigns it, so it cannot check the
# result of that call (the corrected ordering is shown further down the page).
assert jax.numpy.array_equal(y, 3 / 2 * x)  # fails
y = bar.update(x)
print(y)

should be reordered so that `y` is computed before it is checked:

# Corrected order: compute ``y`` first, then compare it with the expected
# scaling ``3 / 2 * x``. (The trailing ``# fails`` note is carried over from
# the originally posted snippet.)
y = bar.update(x)
assert jax.numpy.array_equal(y, 3 / 2 * x)  # fails
print(y)

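Also, it seems that you don't use `_replace` correctly: a `NamedTuple` is immutable, so `_replace` does not modify it in place. It returns a new instance, which you must assign back (e.g. `self.properties = self.properties._replace(a=3)`).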

from functools import partial
from typing import NamedTuple

import jax

# Immutable record holding the two integers that ``Foo.update`` divides
# (``a / b``). Because it is a tuple, "changing" a field means building a new
# instance with ``_replace`` and assigning that result back.
Properties = NamedTuple("Properties", [("a", int), ("b", int)])


class Foo:
    """Hold a ``Properties`` record and scale inputs by ``a / b``.

    ``update`` is jitted with ``self`` marked static (``static_argnums=0``).
    JAX requires static arguments to be hashable and uses ``hash`` / ``==``
    on them to key its compilation cache. With the default identity-based
    hash, replacing ``self.properties`` after the first call would silently
    reuse a trace compiled for the old values. ``__hash__`` and ``__eq__``
    are therefore defined in terms of ``properties``, so a changed record
    triggers a retrace.

    NOTE(review): a subclass that adds further state read inside ``update``
    should fold that state into ``__hash__``/``__eq__`` as well.
    """

    def __init__(self, properties: Properties):
        self.properties = properties

    def __hash__(self):
        # Include the concrete type so different subclasses never share a
        # cache entry. Assumes every field of ``properties`` is hashable
        # (true for the ``int`` fields ``Properties`` declares).
        return hash((type(self), self.properties))

    def __eq__(self, other):
        # Require an exact type match, for the same reason as in __hash__.
        # Returning NotImplemented lets Python fall back to its default
        # comparison for unrelated types.
        if type(self) is not type(other):
            return NotImplemented
        return self.properties == other.properties

    @partial(jax.jit, static_argnums=0)
    def update(self, x):
        """Return ``x`` scaled by ``properties.a / properties.b``.

        Because ``self`` is static, the ratio is computed in Python while
        tracing and baked into the compiled function as a constant. Only
        ``x`` is traced.
        """
        return self.properties.a / self.properties.b * x


class Bar(Foo):
    def __init__(self, properties):
        super().__init__(properties)
        s…

Replies: 2 comments

Comment options

You must be logged in to vote
0 replies
Answer selected by epignatelli
Comment options

You must be logged in to vote
0 replies
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
3 participants