-
Hi all, I often find it convenient to define different levels of functionality through inheritance and to override some class methods depending on the use case. In practice, in my current project I have a loss function F(params) that needs to be computed differently depending on the nature of "params", the latter describing the parameters of another quantity, which can be either linear or nonlinear. Many subroutines used to compute F are shared between the linear and nonlinear cases, but some are not. When I was on PyTorch I relied heavily on class interfaces and redefined only the differing methods in child classes, i.e. I would have had a BaseF class with the common subroutines, then LinearF(BaseF) and NonLinearF(BaseF) that only redefine what differs. This whole approach doesn't seem to fit well within the JAX philosophy. On this page https://jax.readthedocs.io/en/latest/jax-101/07-state.html I read at the end:
The following alternatives seem a bit verbose to me:
Is there another standard alternative that I'm not seeing? One thing I have in mind is using inheritance but with only static methods (and no abstract methods). But this seems like the most degenerate case for using classes. Thanks in advance!
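A minimal sketch of the static-methods-only idea applied to the BaseF / LinearF / NonLinearF layout described above. The method names (`squared_error`, `predict`, `loss`) and the toy linear and tanh models are hypothetical stand-ins for the real F, not anything from the original post:

```python
import jax
import jax.numpy as jnp


class BaseF:
    # Shared subroutine, identical for the linear and nonlinear cases.
    @staticmethod
    def squared_error(pred, target):
        return jnp.mean((pred - target) ** 2)


class LinearF(BaseF):
    @staticmethod
    def predict(params, x):
        # Hypothetical linear sub-model: w * x + b
        return params["w"] * x + params["b"]

    @staticmethod
    def loss(params, x, y):
        # A staticmethod has no `cls`, so the child must name itself
        # explicitly; BaseF cannot dispatch to the child's predict on its own.
        return BaseF.squared_error(LinearF.predict(params, x), y)


class NonLinearF(BaseF):
    @staticmethod
    def predict(params, x):
        # Hypothetical nonlinear sub-model: tanh(w * x + b)
        return jnp.tanh(params["w"] * x + params["b"])

    @staticmethod
    def loss(params, x, y):
        return BaseF.squared_error(NonLinearF.predict(params, x), y)


params = {"w": jnp.array(2.0), "b": jnp.array(0.5)}
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.5, 2.5, 4.5])

# Static methods looked up on the class are plain functions,
# so jit/grad apply to them directly.
lin_grads = jax.jit(jax.grad(LinearF.loss))(params, x, y)
nonlin_grads = jax.jit(jax.grad(NonLinearF.loss))(params, x, y)
```

This works with JAX transforms, but each child has to re-spell `loss` just to point it at its own `predict`, which is the duplication that makes this pattern feel degenerate.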
-
I think it works with the classmethod! I thought I was gonna run into trouble with the implicit first argument "cls" when using JAX functional transforms, but apparently that's handled appropriately. Still, if there's another, more standard way of proceeding, I'd be glad to know about it.
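A sketch of how the classmethod version might look, reusing the same hypothetical names and toy models as a stand-in for the real F. The likely reason `cls` causes no trouble is that `LinearF.loss` is already a bound method when JAX receives it: `cls` is filled in by the attribute lookup, so `jax.grad` only sees `(params, x, y)` and differentiates with respect to `params`.

```python
import jax
import jax.numpy as jnp


class BaseF:
    @staticmethod
    def squared_error(pred, target):
        return jnp.mean((pred - target) ** 2)

    @classmethod
    def loss(cls, params, x, y):
        # `cls` is whichever subclass the method was looked up on, so this
        # single shared definition dispatches to the overriding `predict`.
        return cls.squared_error(cls.predict(params, x), y)

    @staticmethod
    def predict(params, x):
        raise NotImplementedError("subclasses must define predict")


class LinearF(BaseF):
    @staticmethod
    def predict(params, x):
        return params["w"] * x + params["b"]


class NonLinearF(BaseF):
    @staticmethod
    def predict(params, x):
        return jnp.tanh(params["w"] * x + params["b"])


params = {"w": jnp.array(2.0), "b": jnp.array(0.5)}
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.5, 2.5, 4.5])

# LinearF.loss has cls bound already; grad w.r.t. its first argument, params.
grad_lin = jax.jit(jax.grad(LinearF.loss))(params, x, y)
grad_nonlin = jax.jit(jax.grad(NonLinearF.loss))(params, x, y)
```

Compared with the static-only version, only `predict` differs between the children, which is close to the PyTorch-style layout.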
-
I find that this pattern works:

```python
import jax

class parent:
    foo = jax.jit(lambda x: x + 1)

class childA(parent):
    bar = jax.jit(lambda x: x + 2)

class childB(parent):
    foo = jax.jit(lambda x: x + 3)

print(parent.foo(0))  # 1
print(childA.foo(0))  # 1
print(childA.bar(0))  # 2
print(childB.foo(0))  # 3
```
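One caveat, as I understand it: jitted functions behave like methods when looked up through an instance, so with the class-attribute pattern above `childB().foo(0)` would pass the instance as an extra argument and fail. Wrapping each jitted function in `staticmethod` should make it behave the same whether called on the class or on an instance (a sketch, same toy functions as above):

```python
import jax


class parent:
    # staticmethod prevents the jitted function from being bound to an
    # instance, so `self` is never passed in as an extra argument.
    foo = staticmethod(jax.jit(lambda x: x + 1))


class childA(parent):
    bar = staticmethod(jax.jit(lambda x: x + 2))


class childB(parent):
    foo = staticmethod(jax.jit(lambda x: x + 3))


print(childB().foo(0))  # 3
print(childA().bar(0))  # 2
```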
-
You may find Equinox interesting. It's framed as a "neural network library" but it's actually something a bit more general: a parameterised function library. In particular this gives a reasonable class-based/inheritance-supporting way of working with JAX. As an example quite close to your own: Diffrax uses Equinox very extensively to define hierarchies of abstract parameterised functions. For example, you can define:

```python
import abc

import equinox as eqx
import jax.numpy as jnp

class AbstractFoo(eqx.Module):
    param: jnp.ndarray

    def bar(self, other):
        return self.param + other + self.baz()

    @abc.abstractmethod
    def baz(self):
        pass

class MyFoo(AbstractFoo):
    another_param: jnp.ndarray

    def baz(self):
        return self.another_param
```

Which you can use in the way you expect:

```python
my_foo = MyFoo(jnp.array(1.0), jnp.array(2.0))
out = my_foo.bar(3)
```

and moreover can interoperate with JAX transforms:

```python
import jax

@jax.jit
@jax.grad
def call(foo, x):
    return foo.bar(x)

call(my_foo, 2.0)
```
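For comparison, a rough sketch of the same `AbstractFoo` / `MyFoo` example written against plain `jax.tree_util.register_pytree_node_class` instead of Equinox. This is my own illustration of the general idea (make the object a pytree so transforms can see through it), not how Equinox is implemented internally:

```python
import abc

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class AbstractFoo(abc.ABC):
    def __init__(self, param):
        self.param = param

    def bar(self, other):
        return self.param + other + self.baz()

    @abc.abstractmethod
    def baz(self):
        pass


# Pytree registration is looked up by exact type, so every concrete
# subclass needs its own decorator and flatten/unflatten pair.
@register_pytree_node_class
class MyFoo(AbstractFoo):
    def __init__(self, param, another_param):
        super().__init__(param)
        self.another_param = another_param

    def baz(self):
        return self.another_param

    def tree_flatten(self):
        # The array fields are the pytree children; no static aux data.
        return (self.param, self.another_param), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
@jax.grad
def call(foo, x):
    return foo.bar(x)


my_foo = MyFoo(jnp.array(1.0), jnp.array(2.0))
grads = call(my_foo, 2.0)  # a MyFoo holding d(bar)/d(field) for each field
```

The per-subclass registration and flatten/unflatten methods are the boilerplate that `eqx.Module` takes care of, since its subclasses are automatically dataclasses and pytrees.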