Best Practices for Composition of Configuration #10714
Given the following pseudocode:

```python
import jax

def g(x, a, b, c, d):
    # assume b, c, d must be static for jit but not a
    pass

def f(x1, a1, b1, c1, d1, x2, a2, b2, c2, d2, x3, a3, b3, c3, d3):
    out1 = g(x1, a1, b1, c1, d1)
    out2 = g(x2, a2, b2, c2, d2)
    out3 = g(x3, a3, b3, c3, d3)
    return out1 + out2 + out3

fjit = jax.jit(f, static_argnames=["b1", ...])
fjit(...)
```

The function signature of …

```python
def f(x1, x2, x3, config1, config2, config3):
    ...

fjit = jax.jit(f, static_argnames=["config1", "config2", "config3"])
```

In this case, however, we still have these …

Is there an alternative, established best practice for handling scenarios where some level of "static-ness" is required without resorting to flattened function signatures?
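For comparison, here is a minimal sketch of the grouped-config approach from the question. It assumes each config holds only the static values (`b`, `c`, `d`) and that the dynamic `a` is passed separately. `StaticConfig` and the body of `g` are hypothetical. A frozen dataclass is hashable, which `jax.jit` requires of every argument listed in `static_argnames`. The catch is that `a` cannot live inside that config, because a JAX array is not hashable.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class StaticConfig:
    # Hypothetical static fields; frozen=True generates __hash__ and __eq__,
    # which jax.jit needs for anything listed in static_argnames.
    b: int
    c: float
    d: str


def g(x, a, cfg: StaticConfig):
    # cfg.d is a plain Python string, so this branch is resolved at trace time.
    if cfg.d == "square":
        x = x ** 2
    return a * x * cfg.b + cfg.c


def f(x1, a1, cfg1, x2, a2, cfg2):
    return g(x1, a1, cfg1) + g(x2, a2, cfg2)


# static_argnames also covers cfg1/cfg2 when they are passed positionally.
fjit = jax.jit(f, static_argnames=["cfg1", "cfg2"])

cfg = StaticConfig(b=2, c=1.0, d="square")
out = fjit(jnp.arange(3.0), 0.5, cfg, jnp.ones(3), 1.0,
           StaticConfig(b=1, c=0.0, d="identity"))
```

Each call site still passes an `(x, a, config)` triple, so this only shrinks the signature from five to three arguments per problem.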
IIUC, a custom pytree meets your requirements exactly: the arrays go into the pytree's children (traced by `jit`), and the hashable values go into its auxiliary data (treated as static).

```python
from typing import Hashable
from dataclasses import dataclass
import jax.numpy as jnp
from jax import tree_util

@dataclass
class ProblemDef:
    x: jnp.ndarray
    a: jnp.ndarray
    b: Hashable  # Note: ndarray is not hashable, so a shim is needed
    c: Hashable
    d: Hashable

    def tree_flatten(self):
        children = (self.x, self.a)  # arrays / dynamic values
        aux_data = (self.b, self.c, self.d)  # hashable static values
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        x, a = children
        b, c, d = aux_data
        return cls(x, a, b, c, d)

tree_util.register_pytree_node_class(ProblemDef)

# Note: You should not mark ProblemDef as a static argument.
```

See also:
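Here is a minimal usage sketch. The body of `g` and the `solve` function that bundles three problems are hypothetical. `ProblemDef` instances are passed to `jax.jit` as ordinary (non-static) arguments. Inside the traced function, `b`, `c` and `d` arrive as plain Python objects, so they can drive Python control flow, while `x` and `a` are tracers. The class is registered with the decorator form of `register_pytree_node_class`, which is equivalent to the call above. The module-level `trace_count` is an illustration device that is incremented only while JAX traces `solve`. It should show that new array values reuse the compiled function, while a new `b` causes a retrace.

```python
from dataclasses import dataclass
from typing import Hashable

import jax
import jax.numpy as jnp
from jax import tree_util

trace_count = 0  # bumped only when jax.jit (re)traces `solve`


@tree_util.register_pytree_node_class
@dataclass
class ProblemDef:
    x: jnp.ndarray
    a: jnp.ndarray
    b: Hashable
    c: Hashable
    d: Hashable

    def tree_flatten(self):
        return (self.x, self.a), (self.b, self.c, self.d)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


def g(p: ProblemDef):
    # Hypothetical body: p.d is a Python string, so this `if` runs at trace time.
    y = p.x ** 2 if p.d == "square" else p.x
    return p.a * y * p.b + p.c


@jax.jit
def solve(p1, p2, p3):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return g(p1) + g(p2) + g(p3)


q = ProblemDef(jnp.ones(3), jnp.array(1.0), 1, 0.0, "identity")

# First call: traces and compiles.
out1 = solve(ProblemDef(jnp.arange(3.0), jnp.array(0.5), 2, 1.0, "square"), q, q)
# New array values, same static (b, c, d): the compiled function is reused.
out2 = solve(ProblemDef(jnp.arange(3.0) * 2.0, jnp.array(0.25), 2, 1.0, "square"), q, q)
# Different b changes the treedef, so solve is traced again.
out3 = solve(ProblemDef(jnp.arange(3.0), jnp.array(0.5), 3, 1.0, "square"), q, q)
```

Because the static values live in the treedef rather than in `static_argnames`, the call signature stays one argument per problem.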
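The comment about needing a shim when a static value is an array can be sketched as follows, assuming NumPy is available. `HashableArray` is a hypothetical wrapper, not part of JAX. It copies the array, makes the copy read-only, and derives `__hash__` and `__eq__` from the shape, dtype and raw bytes, so equal objects always hash equally. Such a wrapper could then serve as a static value like `b`. To keep the example short, it is passed through `static_argnames` directly rather than through `ProblemDef`.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


class HashableArray:
    """Hypothetical shim: an immutable NumPy array usable as a static value."""

    def __init__(self, value):
        self.value = np.array(value)      # private copy; later edits to `value` cannot leak in
        self.value.setflags(write=False)  # read-only, so the hash stays valid

    def _key(self):
        # Shape, dtype and raw bytes together determine the contents.
        return (self.value.shape, self.value.dtype.str, self.value.tobytes())

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return isinstance(other, HashableArray) and self._key() == other._key()


@partial(jax.jit, static_argnames=["b"])
def scale(x, b: HashableArray):
    # b.value is a concrete NumPy array at trace time and is embedded as a constant.
    return x * jnp.asarray(b.value)


out = scale(jnp.ones(3), b=HashableArray([1.0, 2.0, 3.0]))
```

The bytes are rehashed on every call, so this is only sensible for small constant arrays.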