Skip to content

Commit afda975

Browse files
committed
Protect from crash as in issue lucasb-eyer#29
1 parent bbf9701 commit afda975

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

pydensecrf/densecrf.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ cdef extern from "densecrf/include/densecrf.h":
8686

8787
cdef class DenseCRF:
8888
cdef c_DenseCRF *_this
89+
cdef int _nlabel
90+
cdef int _nvar
8991

9092

9193
cdef class DenseCRF2D(DenseCRF):

pydensecrf/densecrf.pyx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,30 @@ cdef class DenseCRF:
5959
else:
6060
self._this = NULL
6161

62+
self._nvar = nvar
63+
self._nlabel = nlabels
64+
6265
def __dealloc__(self):
6366
# Because destructors are virtual, this is enough to delete any object
6467
# of child classes too.
6568
if self._this:
6669
del self._this
6770

6871
def addPairwiseEnergy(self, float[:,::1] features not None, compat, KernelType kernel=DIAG_KERNEL, NormalizationType normalization=NORMALIZE_SYMMETRIC):
72+
if features.shape[0] != self._nlabel or features.shape[1] != self._nvar:
73+
raise ValueError("Bad shape for pairwise energy (Need {}, got {})".format((self._nlabel, self._nvar), (features.shape[0], features.shape[1])))
74+
6975
self._this.addPairwiseEnergy(eigen.c_matrixXf(features), _labelcomp(compat), kernel, normalization)
7076

7177
def setUnary(self, Unary u):
7278
self._this.setUnaryEnergy(u.move())
7379

7480
def setUnaryEnergy(self, float[:,::1] u not None, float[:,::1] f = None):
81+
if u.shape[0] != self._nlabel or u.shape[1] != self._nvar:
82+
raise ValueError("Bad shape for unary energy (Need {}, got {})".format((self._nlabel, self._nvar), (u.shape[0], u.shape[1])))
83+
# TODO: I don't remember the exact shape `f` should have, so I'm not putting an assertion here.
84+
# If you get hit by a wrong shape of `f`, please open an issue with the necessary info!
85+
7586
if f is None:
7687
self._this.setUnaryEnergy(eigen.c_matrixXf(u))
7788
else:

tests/issue29.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# probs of shape 3d image per class: Nb_classes x Height x Width x Depth
2+
# assume the image has shape (69, 51, 72)
3+
import numpy as np
4+
import pydensecrf.densecrf as dcrf
5+
from pydensecrf.utils import unary_from_softmax, create_pairwise_gaussian
6+
7+
###
8+
9+
#shape = (69, 51, 72)
10+
#probs = np.random.randn(5, 69, 51).astype(np.float32)
11+
#probs /= probs.sum(axis=0, keepdims=True)
12+
#
13+
#d = dcrf.DenseCRF(np.prod(shape), probs.shape[0])
14+
#U = unary_from_softmax(probs)
15+
#print(U.shape)
16+
#d.setUnaryEnergy(U)
17+
#feats = create_pairwise_gaussian(sdims=(1.0, 1.0, 1.0), shape=shape)
18+
#d.addPairwiseEnergy(feats, compat=3, kernel=dcrf.FULL_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)
19+
#Q = d.inference(5)
20+
#new_image = np.argmax(Q, axis=0).reshape((shape[0], shape[1],shape[2]))
21+
22+
23+
###
24+
25+
SHAPE, NLABELS = (69, 51, 72), 5
26+
probs = np.random.randn(NLABELS, 68, 50).astype(np.float32) # WRONG shape here
27+
probs /= probs.sum(axis=0, keepdims=True)
28+
29+
d = dcrf.DenseCRF(np.prod(SHAPE), NLABELS)
30+
31+
d.setUnaryEnergy(unary_from_softmax(probs)) # THIS SHOULD THROW and not crash later
32+
feats = create_pairwise_gaussian(sdims=(1.0, 1.0, 1.0), shape=SHAPE)
33+
d.addPairwiseEnergy(feats, compat=3, kernel=dcrf.FULL_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)
34+
35+
Q = d.inference(5)
36+
new_image = np.argmax(Q, axis=0).reshape(SHAPE)

0 commit comments

Comments
 (0)