|
| 1 | + |
| 2 | +# coding: utf-8 |
| 3 | + |
| 4 | +# In[1]: |
| 5 | + |
| 6 | +# import sys |
| 7 | +# sys.path.insert(0,'/home/dlr16/Applications/anaconda2/envs/PyDenseCRF/lib/python2.7/site-packages') |
| 8 | + |
| 9 | + |
| 10 | +# In[2]: |
| 11 | + |
| 12 | +import numpy as np |
| 13 | +import matplotlib.pyplot as plt |
| 14 | +# get_ipython().magic(u'matplotlib inline') |
| 15 | +plt.rcParams['figure.figsize'] = (20, 20) |
| 16 | +plt.rcParams['image.interpolation'] = 'nearest' |
| 17 | +plt.rcParams['image.cmap'] = 'gray' |
| 18 | + |
| 19 | +import pydensecrf.densecrf as dcrf |
| 20 | +from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian |
| 21 | + |
| 22 | + |
| 23 | +# ## Start from scratch |
| 24 | + |
| 25 | +# In[3]: |
| 26 | + |
| 27 | +from scipy.stats import multivariate_normal |
| 28 | + |
| 29 | +x, y = np.mgrid[0:512, 0:512] |
| 30 | +pos = np.empty(x.shape + (2,)) |
| 31 | +pos[:, :, 0] = x; pos[:, :, 1] = y |
| 32 | +rv = multivariate_normal([256, 256], 128*128) |
| 33 | + |
| 34 | + |
| 35 | +# In[4]: |
| 36 | + |
| 37 | +probs = rv.pdf(pos) |
| 38 | +probs = (probs-probs.min()) / (probs.max()-probs.min()) |
| 39 | +probs = 0.2 * (probs-0.5) + 0.5 |
| 40 | +probs = np.tile(probs[:,:,np.newaxis],(1,1,2)) |
| 41 | +probs[:,:,1] = 1 - probs[:,:,0] |
| 42 | +# plt.plot(probs[256,:,0]) |
| 43 | + |
| 44 | +# transpose for graph |
| 45 | +probs = np.transpose(probs,(2,0,1)) |
| 46 | + |
| 47 | + |
| 48 | +# In[17]: |
| 49 | + |
| 50 | +# XX:IF NCHANNELS != 3, I GET ERRONEOUS OUTPUT |
| 51 | +nchannels=4 |
| 52 | + |
| 53 | +U = unary_from_softmax(probs) # note: num classes is first dim |
| 54 | +d = dcrf.DenseCRF2D(probs.shape[1],probs.shape[2],probs.shape[0]) |
| 55 | +d.setUnaryEnergy(U) |
| 56 | + |
| 57 | +Q_Unary = d.inference(10) |
| 58 | +map_soln_Unary = np.argmax(Q_Unary, axis=0).reshape((probs.shape[1],probs.shape[2])) |
| 59 | + |
| 60 | +tmp_img = np.zeros((probs.shape[1],probs.shape[2],nchannels)).astype(np.uint8) |
| 61 | +tmp_img[150:362,150:362,:] = 1 |
| 62 | + |
| 63 | +energy = create_pairwise_bilateral(sdims=(10,10), schan=0.01, img=tmp_img, chdim=2) |
| 64 | +d.addPairwiseEnergy(energy, compat=10) |
| 65 | + |
| 66 | +# This is wrong and will now raise a ValueError: |
| 67 | +#d.addPairwiseBilateral(sxy=(10,10), |
| 68 | +# srgb=0.01, |
| 69 | +# rgbim=tmp_img, |
| 70 | +# compat=10) |
| 71 | + |
| 72 | +Q = d.inference(100) |
| 73 | +map_soln = np.argmax(Q, axis=0).reshape((probs.shape[1],probs.shape[2])) |
| 74 | + |
| 75 | +plt.subplot(2,2,1) |
| 76 | +plt.imshow(probs[0,:,:]) |
| 77 | +plt.colorbar() |
| 78 | +plt.subplot(2,2,2) |
| 79 | +plt.imshow(map_soln_Unary) |
| 80 | +plt.colorbar() |
| 81 | +plt.subplot(2,2,3) |
| 82 | +plt.imshow(tmp_img[:,:,0]) |
| 83 | +plt.colorbar() |
| 84 | +plt.subplot(2,2,4) |
| 85 | +plt.imshow(map_soln) |
| 86 | +plt.colorbar() |
| 87 | +plt.show() |
0 commit comments