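IIUC, a custom pytree is exactly what you need. The array fields are returned as children (leaves that JAX traces), and the hashable fields go into `aux_data`, which JAX treats as static.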

```python
from typing import Hashable
from dataclasses import dataclass
import jax.numpy as jnp
from jax import tree_util

@dataclass
class ProblemDef:
    x: jnp.ndarray
    a: jnp.ndarray
    b: Hashable  # Note: arrays are not hashable, so a shim (wrapper) is needed to store one here
    c: Hashable
    d: Hashable

    def tree_flatten(self):
        children = (self.x, self.a)  # arrays / dynamic values: become pytree leaves
        aux_data = (self.b, self.c, self.d)  # hashable static values: stored in the treedef
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        x, a = children
        b, c, d = aux_data
        return cls(x, a, b, c, d)

tree_util.reg…
```
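The snippet is cut off at the registration call. As a minimal sketch, assuming registration via `jax.tree_util.register_pytree_node_class` (a real JAX decorator that uses the class's own `tree_flatten`/`tree_unflatten`; not necessarily the call used above), the class can be passed straight into a jitted function. The `solve` function and the field values are hypothetical, for illustration only.

```python
from collections.abc import Hashable
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class  # assumption: one way to register the class
@dataclass
class ProblemDef:
    x: jnp.ndarray
    a: jnp.ndarray
    b: Hashable
    c: Hashable
    d: Hashable

    def tree_flatten(self):
        children = (self.x, self.a)          # leaves: traced by jit/grad/vmap
        aux_data = (self.b, self.c, self.d)  # static: part of the treedef (and the jit cache key)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        x, a = children
        b, c, d = aux_data
        return cls(x, a, b, c, d)


@jax.jit
def solve(p: ProblemDef) -> jnp.ndarray:
    # Hypothetical objective. p.b is a plain Python value at trace time,
    # so ordinary Python control flow on it is allowed inside jit.
    if p.b == "square":
        return jnp.sum((p.x - p.a) ** 2) * p.c
    return jnp.sum(jnp.abs(p.x - p.a)) * p.c


p = ProblemDef(x=jnp.ones(3), a=jnp.zeros(3), b="square", c=2.0, d=None)
leaves, treedef = tree_util.tree_flatten(p)
print(len(leaves))  # 2: only x and a are leaves
print(solve(p))
```

Because `b`, `c` and `d` live in the treedef, changing any of them triggers a retrace, while new values of `x` and `a` with the same shape and dtype reuse the compiled function.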
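The comment on `b` notes that arrays are not hashable, so an array cannot go into `aux_data` as-is. A minimal sketch of such a shim, assuming the array is small and concrete (a NumPy array, not a tracer): a hypothetical `HashableArray` wrapper, not a JAX API, that hashes and compares on shape, dtype and raw bytes, so hashing and equality stay consistent. Since `jax.jit`'s `static_argnums` places the same hashable-and-comparable requirement on its arguments, the sketch uses it to check that equal-valued wrappers hit the compilation cache.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


class HashableArray:
    """Hypothetical shim: wraps a concrete NumPy array so it can be used as static data."""

    __slots__ = ("value",)

    def __init__(self, value):
        self.value = np.asarray(value)

    def _key(self):
        # Everything that defines the array's contents, compared bitwise.
        return (self.value.shape, self.value.dtype.str, self.value.tobytes())

    def __hash__(self) -> int:
        return hash(self._key())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, HashableArray):
            return NotImplemented
        return self._key() == other._key()


traces = []  # one entry per trace of `scale`


@partial(jax.jit, static_argnums=1)
def scale(x, w: HashableArray):
    traces.append(w.value.shape)  # Python side effect: runs only while tracing
    return x * w.value            # w.value is embedded in the compiled function as a constant


x = jnp.ones(3)
w1 = HashableArray(np.array([1.0, 2.0, 3.0], dtype=np.float32))
w2 = HashableArray(np.array([1.0, 2.0, 3.0], dtype=np.float32))  # distinct object, equal value
out = scale(x, w1)
scale(x, w2)  # same hash and __eq__ -> cache hit, no retrace
traces_after_equal = len(traces)
scale(x, HashableArray(np.array([4.0, 5.0, 6.0], dtype=np.float32)))  # new value -> retrace
```

An instance stored in `b`, `c` or `d` ends up in `aux_data` the same way; keep such arrays small, since every distinct value compiles a separate version of the function.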

Answer selected by ntenenz
Category: Q&A