@@ -517,15 +517,17 @@ cdef class Tree:
517517
518518 node_ndarray = d[' nodes' ]
519519 value_ndarray = d[' values' ]
520-
521- value_shape = (node_ndarray.shape[0 ], self .n_outputs)
520+
522521 if (node_ndarray.ndim != 1 or
523522 node_ndarray.dtype != NODE_DTYPE or
524- not node_ndarray.flags.c_contiguous or
525- value_ndarray.shape != value_shape or
523+ not node_ndarray.flags.c_contiguous):
524+ raise ValueError (' Did not recognise loaded array layout for `node_ndarray`' )
525+
526+ value_shape = (node_ndarray.shape[0 ], self .n_outputs, self .max_n_classes)
527+ if (value_ndarray.shape != value_shape or
526528 not value_ndarray.flags.c_contiguous or
527529 value_ndarray.dtype != np.float64):
528- raise ValueError (' Did not recognise loaded array layout' )
530+ raise ValueError (' Did not recognise loaded array layout for `value_ndarray` ' )
529531
530532 self .capacity = node_ndarray.shape[0 ]
531533 if self ._resize_c(self .capacity) != 0 :
@@ -541,15 +543,15 @@ cdef class Tree:
541543 if (jac_ndarray.shape != jac_shape or
542544 not jac_ndarray.flags.c_contiguous or
543545 jac_ndarray.dtype != np.float64):
544- raise ValueError (' Did not recognise loaded array layout' )
546+ raise ValueError (' Did not recognise loaded array layout for `jac_ndarray` ' )
545547 jac = memcpy(self .jac, (< np.ndarray> jac_ndarray).data,
546548 self .capacity * self .jac_stride * sizeof(double ))
547549 precond_ndarray = d[' precond' ]
548550 precond_shape = (node_ndarray.shape[0 ], self .n_outputs)
549551 if (precond_ndarray.shape != precond_shape or
550552 not precond_ndarray.flags.c_contiguous or
551553 precond_ndarray.dtype != np.float64):
552- raise ValueError (' Did not recognise loaded array layout' )
554+ raise ValueError (' Did not recognise loaded array layout for `precond_ndarray` ' )
553555 precond = memcpy(self .precond, (< np.ndarray> precond_ndarray).data,
554556 self .capacity * self .precond_stride * sizeof(double ))
555557
@@ -917,7 +919,7 @@ cdef class Tree:
917919 cdef np.npy_intp shape[3 ]
918920 shape[0 ] = < np.npy_intp> self .node_count
919921 shape[1 ] = < np.npy_intp> self .n_outputs
920- shape[2 ] = 1
922+ shape[2 ] = < np.npy_intp > self .max_n_classes
921923 cdef np.ndarray arr
922924 arr = np.PyArray_SimpleNewFromData(3 , shape, np.NPY_DOUBLE, self .value)
923925 Py_INCREF(self )
0 commit comments