Hi,
I have a project that needs to convert an input image into a pyramid and then extract features from each scale of the pyramid with a CNN. I'm using Flax, by the way.
In the code below, I tried `pyramid_gaussian` from skimage, but it cannot handle a batch of images, and it is not jit-able: jitting a function that calls this module raises a `TracerArrayConversionError`.
So how can I build an image pyramid in JAX?
Alternatively, we could treat the image pyramid as a data-processing step: use a TensorFlow or PyTorch pyramid function to build a batch of image pyramids, and pass the pyramid directly to the module as input. What do you think?
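```python
import flax.linen as nn
from skimage.transform import pyramid_gaussian


class Foo(nn.Module):
    @nn.compact
    def __call__(self, image):
        # image: a single (H, W, C) image; pyramid_gaussian does not accept a batch.
        # Non-jit-able: under jax.jit this call raises TracerArrayConversionError.
        pyra = tuple(pyramid_gaussian(image, downscale=2, channel_axis=-1))
        pyra = [scale[None, ...] for scale in pyra]  # add a batch axis to each scale
        r = 0.0
        for scale in pyra:
            x = nn.Conv(2, kernel_size=(1, 1), strides=(1, 1), padding="SAME")(scale)
            r += x.mean(axis=(1, 2, 3))
        return r
```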