Skip to content

Commit 9096c7b

Browse files
committed
Fix segfault on wrong compat dtype.
The exception is now correctly propagated instead.
1 parent 1c40e00 commit 9096c7b

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

pydensecrf/densecrf.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import eigen
88
cimport eigen
99

1010

11-
cdef LabelCompatibility* _labelcomp(compat):
11+
cdef LabelCompatibility* _labelcomp(compat) except NULL:
1212
if isinstance(compat, Number):
1313
return new PottsCompatibility(compat)
1414
elif memoryview(compat).ndim == 1:
@@ -17,6 +17,7 @@ cdef LabelCompatibility* _labelcomp(compat):
1717
return new MatrixCompatibility(eigen.c_matrixXf(compat))
1818
else:
1919
raise ValueError("LabelCompatibility of dimension >2 not meaningful.")
20+
return NULL # Important for the exception(s) to propagate!
2021

2122

2223
cdef class Unary:

pydensecrf/test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
import numpy as np
2-
import densecrf as dcrf
2+
import pydensecrf.densecrf as dcrf
3+
4+
# TODO: Make this real unit-tests some time in the future...
5+
6+
# Tests for specific issues
7+
###########################
8+
9+
# Via e-mail: crash when non-float32 compat
10+
d = dcrf.DenseCRF2D(10,10,2)
11+
d.setUnaryEnergy(np.ones((2,10*10), dtype=np.float32))
12+
compat = np.array([1.0, 2.0])
13+
try:
14+
d.addPairwiseBilateral(sxy=(3,3), srgb=(3,3,3), rgbim=np.zeros((10,10,3), np.uint8), compat=compat)
15+
d.inference(2)
16+
raise TypeError("Didn't raise an exception, but should because compat dtypes don't match!!")
17+
except ValueError:
18+
pass # That's what we want!
19+
20+
21+
# The following is not a really good unittest, but was the first tests.
22+
###########################
323

424
# d = densecrf.PyDenseCRF2D(3, 2, 3)
525
# U = np.full((3,6), 0.1, dtype=np.float32)
@@ -25,3 +45,4 @@
2545
d.addPairwiseBilateral(2, 2, img, 3)
2646
# d.addPairwiseBilateral(2, 2, img, 3)
2747
np.argmax(d.inference(10), axis=0).reshape(10,10)
48+

0 commit comments

Comments
 (0)