Selecting a submodule from a list according to a tracer #12193
Unanswered
icysapphire
asked this question in
Q&A
Replies: 1 comment
-
Hi - the short answer is no, you cannot use a tracer to index into a Python list. But in the specific case of selecting from a list of functions to apply based on a traced value, you can use jax.lax.switch:

    # return self.pool_conv[conv_idx](sample)
    return jax.lax.switch(conv_idx, self.pool_conv, sample)

There is one caveat:
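As a standalone illustration of the jax.lax.switch suggestion above, here is a minimal sketch in plain JAX. The three branch functions and the name apply_selected are hypothetical stand-ins for the pool of layers, not the module from the question. Note that jax.lax.switch expects every branch to accept the same operands and return outputs of matching structure, shape and dtype.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the pool of layers. Every branch takes the same
# operand and returns an array with the same shape and dtype.
def scale_by_two(x):
    return x * 2.0

def add_one(x):
    return x + 1.0

def negate(x):
    return -x

branches = [scale_by_two, add_one, negate]

@jax.jit
def apply_selected(idx, x):
    # Inside jit, idx is a tracer, so branches[idx](x) would fail:
    # a Python list cannot be indexed by a traced value.
    # lax.switch traces all branches and selects one at run time.
    return jax.lax.switch(idx, branches, x)

x = jnp.arange(3.0)
out0 = apply_selected(0, x)  # applies scale_by_two
out1 = apply_selected(1, x)  # applies add_one
out2 = apply_selected(2, x)  # applies negate
```

Because the index is a regular (traced) argument here, the function is compiled once and reused for every index value.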
-
I have a pool of convolutional layers and I would like to select which layer to use depending on the value of a tracer. The following is a sketch of my custom module:
The traceback:
I am a newbie and I am trying to figure out a possible implementation that fits JAX's compilation model. Can you help me? It seems that static args do not help here.
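To make the remark about static args concrete, here is a hedged sketch that assumes the index is computed from array data, which may or may not match the original module. The layers list and both function names are hypothetical. Marking the index as static works only when it is a plain Python int at call time; an index derived from traced values is itself a tracer, and using it to index a list raises an error.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the pool of layers.
layers = [lambda x: x * 2.0, lambda x: x + 1.0]

@partial(jax.jit, static_argnames="idx")
def apply_static(x, idx):
    # idx is a concrete Python int here, so list indexing is fine.
    # Each distinct idx value triggers a separate compilation.
    return layers[idx](x)

ok = apply_static(jnp.ones(2), idx=1)

@jax.jit
def apply_data_dependent(x):
    idx = jnp.argmax(x)    # computed from traced data, so idx is a tracer
    return layers[idx](x)  # list indexing needs a concrete Python int

try:
    apply_data_dependent(jnp.array([0.0, 1.0]))
    failed = False
except jax.errors.TracerIntegerConversionError:
    # Raised when a tracer is used where a Python integer is required.
    failed = True
```

In the data-dependent case there is no argument to mark as static, which is why a run-time selection primitive such as jax.lax.switch is the usual fit.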