I'm wondering what the recommended way is to deal with variable input shapes.
Suppose the input shape is (batch_size, K, N, 4), where K differs from sample to sample. The last two dimensions, of shape (N, 4), pass through the same CNN, which outputs shape (result_x, result_y). So each sample goes (K, N, 4) -> (K, result_x, result_y), and summing over K gives (result_x, result_y). Is there a good way to implement this?
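Concretely, the per-sample computation I have in mind is something like the sketch below (cnn here is just a runnable stand-in for the shared network, not my real model):

import jax
import jax.numpy as jnp

def cnn(x_n4):
    # placeholder for the shared CNN: (N, 4) -> (result_x, result_y)
    return x_n4[: x_n4.shape[0] // 2, :2]        # e.g. (N // 2, 2)

def per_sample(sample):                           # sample: (K, N, 4)
    per_k = jax.vmap(cnn)(sample)                 # (K, result_x, result_y)
    return per_k.sum(axis=0)                      # (result_x, result_y)

out = per_sample(jnp.ones((5, 8, 4)))             # K=5, N=8 -> shape (4, 2)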
I've tried padding all samples to the same K, but that seems to take up a lot of memory: in my dataset most Ks are very small, with a few very large ones.
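The padded version looked roughly like this (pad_k is just a name I'm using here to illustrate the idea, and the mask is what I use to zero out the padded slices before the sum over K):

import jax.numpy as jnp

def pad_k(sample, k_max):
    # sample: (K, N, 4); zero-pad the K axis up to k_max and return a validity mask
    k = sample.shape[0]
    padded = jnp.pad(sample, ((0, k_max - k), (0, 0), (0, 0)))
    mask = jnp.arange(k_max) < k                  # True for real slices, False for padding
    return padded, mask

padded, mask = pad_k(jnp.ones((3, 8, 4)), k_max=16)   # K=3 padded to 16
# after the CNN, the padded slices get masked out before summing over K:
# out = (per_k * mask[:, None, None]).sum(axis=0)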
My second attempt allows variable input and vmaps the CNN over the K axis, but does that mean the program recompiles every time a different K is passed?
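I tried to check this with a toy function like the one below, which seems to re-trace once per distinct K, but I'm not sure the same reasoning applies to my full training step (this is not my actual model, just an illustration of the concern):

import jax
import jax.numpy as jnp

traces = 0

@jax.jit
def summed(x):                    # x: (K, N, 4)
    global traces
    traces += 1                   # this Python side effect only runs when jit re-traces
    return x.sum(axis=0)

summed(jnp.ones((3, 8, 4)))
summed(jnp.ones((3, 8, 4)))       # same shape: served from the compilation cache
summed(jnp.ones((7, 8, 4)))       # new K: triggers a new trace / compilation
print(traces)                     # 2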
Here's the code for my second attempt, in case it helps make things clearer. The non-jitted and jitted performance doesn't seem to differ much. Thanks a lot in advance!
import optax
from flax import linen as nn
import jax
from functools import partial
from tqdm import tqdm
import tensorflow as tf
import jax.numpy as jnp
from time import perf_counter

N = 1344
batch_size = 32
out_shape = (N // 2, 2)


class CNN(nn.Module):
    def cnn(self, xx):
        # shared 1-D CNN applied to one (batch_size, N, 4) slice
        xx = nn.Conv(features=2,
                     kernel_size=(2,),
                     use_bias=False)(xx)
        xx = nn.max_pool(xx,
                         window_shape=(2,),
                         strides=(2,),
                         padding='SAME')
        return xx

    @nn.compact
    def __call__(self, x):
        # x: (batch_size, K, N, 4); map the CNN over the variable-length K axis
        x = jax.vmap(self.cnn, in_axes=1, out_axes=1)(x)
        # sum over K -> (batch_size, N // 2, 2)
        x = jnp.sum(x, axis=1)
        return x


class generator:
    def __init__(self) -> None:
        self.i = 1

    def __call__(self):
        # yields batches whose K grows from 1 to batch_size
        while self.i <= batch_size:
            x_tf = jnp.asarray([[jnp.full((N, 4), self.i)] * self.i] * batch_size)  # (batch_size, K, N, 4)
            y_tf = jnp.ones((batch_size, *out_shape))
            self.i += 1
            yield x_tf, y_tf


def get_ds():
    ds = tf.data.Dataset.from_generator(
        generator(),
        output_signature=(
            tf.TensorSpec(shape=(batch_size, None, N, 4), dtype=tf.int8),
            tf.TensorSpec(shape=(batch_size, *out_shape), dtype=tf.float32)
        )
    )
    return ds


def init_model(init_rng=jax.random.PRNGKey(0)):
    model = CNN()
    # dummy input; only the resulting parameter shapes matter here
    params = model.init(init_rng, jnp.zeros((2, N, 4)))
    return model, params['params']


def train_step(model, batch, params, step_size=0.01):
    def mse_loss(params):
        y_pred = model.apply({'params': params}, batch['x'])
        loss = optax.l2_loss(predictions=y_pred, targets=batch['y']).mean()
        return loss

    loss_fn = jax.value_and_grad(mse_loss, has_aux=False)
    _loss, grad = loss_fn(params)
    # plain SGD update
    params = jax.tree_util.tree_map(lambda x, gr: x - step_size * gr, params, grad)
    return params


train_step_jit = jax.jit(train_step, static_argnums=0)


def main():
    model, params = init_model()

    # first pass: un-jitted train_step
    t_start_1 = perf_counter()
    for epoch in range(3):
        ds = get_ds()
        for batch in tqdm(ds):
            batch = {'x': batch[0].numpy(), 'y': batch[1].numpy()}
            params = train_step(model, batch, params)
    t_stop_1 = perf_counter()

    # second pass: jitted train_step (every new K has a different input shape)
    t_start_2 = perf_counter()
    for epoch in range(3):
        ds = get_ds()
        for batch in tqdm(ds):
            batch = {'x': batch[0].numpy(), 'y': batch[1].numpy()}
            params = train_step_jit(model, batch, params)
    t_stop_2 = perf_counter()

    print(f'not_jit: {t_stop_1 - t_start_1}\njit: {t_stop_2 - t_start_2}')


if __name__ == '__main__':
    main()