Using scan with multiple layers, each with a different initialization: working example #14715
Unanswered
jakubMitura14 asked this question in Q&A
Hey @jakubMitura14, I added some
```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Layer(nn.Module):
    dummy: int  # plain module attribute: every scanned layer sees the same value

    @nn.compact
    def __call__(self, c, x, t):
        # x is this step's slice of the scanned input; t is broadcast to every step
        jax.debug.print("x: {x}", x=x)
        jax.debug.print("t: {t}", t=t)
        x = nn.Dense(len(x))(x)
        x = jax.nn.softmax(jnp.exp(t) * x)
        # accumulate a running sum in the carry
        c = c + jnp.sum(jnp.ravel(x))
        jax.debug.print("c: {c}\n", c=c.flatten())
        return c, x  # (new carry, per-step output)


class Model(nn.Module):
    @nn.compact
    def __call__(self, x, t):
        LayerScanned = nn.scan(
            Layer,
            # stack params along a leading layer axis; "dummy" here names a
            # variable collection (unused), not the `dummy` attribute
            variable_axes={"params": 0, "dummy": 0},
            split_rngs={"params": False},  # False: all 5 layers start from the same init
            length=5,
            in_axes=(0, nn.broadcast),  # scan x along axis 0, broadcast t
            out_axes=0,
        )
        carry = jnp.zeros_like(x)
        carry, y = LayerScanned(10)(carry, x, t)
        jax.debug.print("final y: {y}", y=y)
        jax.debug.print("final c: {c}", c=carry)
        return y, carry


x = jax.random.uniform(jax.random.PRNGKey(0), (5, 2))
model = Model()
print("\ninit")
params = model.init(jax.random.PRNGKey(0), x, t=1.0)
print("\napply")
y, c = model.apply(params, x, t=0.1)
```
This post is also published in the Flax forum (https://github.com/google/flax/discussions/2912), so this one can be removed by an administrator.
I used the example from google/flax#2127 as a basis.
I want to have multiple layers of the same type, but each with a different value of a dummy variable.
Then I want to invoke each layer sequentially and sum their outputs into a single number.
The main problem is that when I pass an array to LayerScanned, I get the same array broadcast to every layer instead of a different integer for each layer.
For some reason I also get an array as the final result rather than the cumulative sum (although this is a minor problem).
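On the second point: `nn.scan` is a lifted version of `jax.lax.scan`, which returns a pair `(final_carry, stacked_ys)`. The accumulated value lives in the carry, while the second element is the per-step outputs stacked along a new leading axis. A minimal pure-JAX sketch, using made-up per-layer outputs (`outs`) as an assumption:

```python
import jax
import jax.numpy as jnp


def step(total, layer_out):
    # carry: running total across layers; y: what this step emits
    total = total + jnp.sum(layer_out)
    return total, total


outs = jnp.arange(10.0).reshape(5, 2)  # hypothetical outputs of 5 layers
final_total, running = jax.lax.scan(step, jnp.zeros(()), outs)
print(final_total)  # 45.0
print(running)      # [ 1.  6. 15. 28. 45.]
```

Emitting the running total as `y` (as `step` does) gives the cumulative sum after every layer, the same as `jnp.cumsum(outs.sum(axis=1))`. Note also that in the answer code the carry starts as `jnp.zeros_like(x)` with shape `(5, 2)`, so the scalar sum is broadcast into every entry; starting from a scalar such as `jnp.zeros(())` keeps the carry a single number.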