Skip to content

Commit 8e75175

Browse files
committed
BUG: Ensure traverse info copy works also if in-place (used by nditer)
Also removes a comment that was always fishy: we should have used this function, and we did, except for a funny little in-place bug!
1 parent 8fdce22 commit 8e75175

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

numpy/_core/src/multiarray/dtype_traversal.c

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -398,13 +398,6 @@ subarray_traverse_data_free(NpyAuxData *data)
398398
}
399399

400400

401-
/*
402-
* We seem to be neither using nor exposing this right now, so leave it NULL.
403-
* (The implementation below should be functional.)
404-
*/
405-
#define subarray_traverse_data_clone NULL
406-
407-
#ifndef subarray_traverse_data_clone
408401
/* traverse data copy function */
409402
static NpyAuxData *
410403
subarray_traverse_data_clone(NpyAuxData *data)
@@ -426,7 +419,6 @@ subarray_traverse_data_clone(NpyAuxData *data)
426419

427420
return (NpyAuxData *)newdata;
428421
}
429-
#endif
430422

431423

432424
static int
@@ -469,7 +461,7 @@ get_subarray_traverse_func(
469461

470462
auxdata->count = size;
471463
auxdata->base.free = &subarray_traverse_data_free;
472-
auxdata->base.clone = subarray_traverse_data_clone;
464+
auxdata->base.clone = &subarray_traverse_data_clone;
473465

474466
if (get_traverse_func(
475467
traverse_context, dtype, aligned,

numpy/_core/src/multiarray/dtype_traversal.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,18 +69,22 @@ static inline int
6969
NPY_traverse_info_copy(
7070
NPY_traverse_info *traverse_info, NPY_traverse_info *original)
7171
{
72-
traverse_info->func = NULL;
72+
/* Note that original may be identical to traverse_info! */
7373
if (original->func == NULL) {
7474
/* Allow copying also of unused clear info */
75+
traverse_info->func = NULL;
7576
return 0;
7677
}
77-
traverse_info->auxdata = NULL;
7878
if (original->auxdata != NULL) {
7979
traverse_info->auxdata = NPY_AUXDATA_CLONE(original->auxdata);
8080
if (traverse_info->auxdata == NULL) {
81+
traverse_info->func = NULL;
8182
return -1;
8283
}
8384
}
85+
else {
86+
traverse_info->auxdata = NULL;
87+
}
8488
Py_INCREF(original->descr);
8589
traverse_info->descr = original->descr;
8690
traverse_info->func = original->func;

0 commit comments

Comments
 (0)