@@ -180,7 +180,7 @@ use crate::dtype::Element;
180
180
use crate :: error:: { BorrowError , NotContiguousError } ;
181
181
use crate :: npyffi:: { self , PyArrayObject , NPY_ARRAY_WRITEABLE } ;
182
182
183
- #[ derive( PartialEq , Eq , Hash ) ]
183
+ #[ derive( Clone , Copy , PartialEq , Eq , Hash ) ]
184
184
struct BorrowKey {
185
185
/// exclusive range of lowest and highest address covered by array
186
186
range : ( usize , usize ) ,
@@ -199,7 +199,7 @@ impl BorrowKey {
199
199
let range = data_range ( array) ;
200
200
201
201
let data_ptr = array. data ( ) as usize ;
202
- let gcd_strides = reduce ( array. strides ( ) . iter ( ) . copied ( ) , gcd ) . unwrap_or ( 1 ) ;
202
+ let gcd_strides = gcd_strides ( array. strides ( ) ) ;
203
203
204
204
Self {
205
205
range,
@@ -252,16 +252,9 @@ impl BorrowFlags {
252
252
( * self . 0 . get ( ) ) . get_or_insert_with ( AHashMap :: new)
253
253
}
254
254
255
- fn acquire < T , D > ( & self , array : & PyArray < T , D > ) -> Result < ( ) , BorrowError >
256
- where
257
- T : Element ,
258
- D : Dimension ,
259
- {
260
- let address = base_address ( array) ;
261
- let key = BorrowKey :: from_array ( array) ;
262
-
263
- // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
264
- // and we are not calling into user code which might re-enter this function.
255
+ fn acquire ( & self , _py : Python , address : usize , key : BorrowKey ) -> Result < ( ) , BorrowError > {
256
+ // SAFETY: Having `_py` implies holding the GIL and
257
+ // we are not calling into user code which might re-enter this function.
265
258
let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
266
259
267
260
match borrow_flags. entry ( address) {
@@ -302,16 +295,9 @@ impl BorrowFlags {
302
295
Ok ( ( ) )
303
296
}
304
297
305
- fn release < T , D > ( & self , array : & PyArray < T , D > )
306
- where
307
- T : Element ,
308
- D : Dimension ,
309
- {
310
- let address = base_address ( array) ;
311
- let key = BorrowKey :: from_array ( array) ;
312
-
313
- // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
314
- // and we are not calling into user code which might re-enter this function.
298
+ fn release ( & self , _py : Python , address : usize , key : BorrowKey ) {
299
+ // SAFETY: Having `_py` implies holding the GIL and
300
+ // we are not calling into user code which might re-enter this function.
315
301
let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
316
302
317
303
let same_base_arrays = borrow_flags. get_mut ( & address) . unwrap ( ) ;
@@ -329,16 +315,9 @@ impl BorrowFlags {
329
315
}
330
316
}
331
317
332
- fn acquire_mut < T , D > ( & self , array : & PyArray < T , D > ) -> Result < ( ) , BorrowError >
333
- where
334
- T : Element ,
335
- D : Dimension ,
336
- {
337
- let address = base_address ( array) ;
338
- let key = BorrowKey :: from_array ( array) ;
339
-
340
- // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
341
- // and we are not calling into user code which might re-enter this function.
318
+ fn acquire_mut ( & self , _py : Python , address : usize , key : BorrowKey ) -> Result < ( ) , BorrowError > {
319
+ // SAFETY: Having `_py` implies holding the GIL and
320
+ // we are not calling into user code which might re-enter this function.
342
321
let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
343
322
344
323
match borrow_flags. entry ( address) {
@@ -373,16 +352,9 @@ impl BorrowFlags {
373
352
Ok ( ( ) )
374
353
}
375
354
376
- fn release_mut < T , D > ( & self , array : & PyArray < T , D > )
377
- where
378
- T : Element ,
379
- D : Dimension ,
380
- {
381
- let address = base_address ( array) ;
382
- let key = BorrowKey :: from_array ( array) ;
383
-
384
- // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
385
- // and we are not calling into user code which might re-enter this function.
355
+ fn release_mut ( & self , _py : Python , address : usize , key : BorrowKey ) {
356
+ // SAFETY: Having `_py` implies holding the GIL and
357
+ // we are not calling into user code which might re-enter this function.
386
358
let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
387
359
388
360
let same_base_arrays = borrow_flags. get_mut ( & address) . unwrap ( ) ;
@@ -403,10 +375,16 @@ static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
403
375
/// i.e. that only shared references into the interior of the array can be created safely.
404
376
///
405
377
/// See the [module-level documentation](self) for more.
406
- pub struct PyReadonlyArray < ' py , T , D > ( & ' py PyArray < T , D > )
378
+ #[ repr( C ) ]
379
+ pub struct PyReadonlyArray < ' py , T , D >
407
380
where
408
381
T : Element ,
409
- D : Dimension ;
382
+ D : Dimension ,
383
+ {
384
+ array : & ' py PyArray < T , D > ,
385
+ address : usize ,
386
+ key : BorrowKey ,
387
+ }
410
388
411
389
/// Read-only borrow of a one-dimensional array.
412
390
pub type PyReadonlyArray1 < ' py , T > = PyReadonlyArray < ' py , T , Ix1 > ;
@@ -437,7 +415,7 @@ where
437
415
type Target = PyArray < T , D > ;
438
416
439
417
fn deref ( & self ) -> & Self :: Target {
440
- self . 0
418
+ self . array
441
419
}
442
420
}
443
421
@@ -454,23 +432,30 @@ where
454
432
D : Dimension ,
455
433
{
456
434
pub ( crate ) fn try_new ( array : & ' py PyArray < T , D > ) -> Result < Self , BorrowError > {
457
- BORROW_FLAGS . acquire ( array) ?;
435
+ let address = base_address ( array) ;
436
+ let key = BorrowKey :: from_array ( array) ;
458
437
459
- Ok ( Self ( array) )
438
+ BORROW_FLAGS . acquire ( array. py ( ) , address, key) ?;
439
+
440
+ Ok ( Self {
441
+ array,
442
+ address,
443
+ key,
444
+ } )
460
445
}
461
446
462
447
/// Provides an immutable array view of the interior of the NumPy array.
463
448
#[ inline( always) ]
464
449
pub fn as_array ( & self ) -> ArrayView < T , D > {
465
450
// SAFETY: Global borrow flags ensure aliasing discipline.
466
- unsafe { self . 0 . as_array ( ) }
451
+ unsafe { self . array . as_array ( ) }
467
452
}
468
453
469
454
/// Provide an immutable slice view of the interior of the NumPy array if it is contiguous.
470
455
#[ inline( always) ]
471
456
pub fn as_slice ( & self ) -> Result < & [ T ] , NotContiguousError > {
472
457
// SAFETY: Global borrow flags ensure aliasing discipline.
473
- unsafe { self . 0 . as_slice ( ) }
458
+ unsafe { self . array . as_slice ( ) }
474
459
}
475
460
476
461
/// Provide an immutable reference to an element of the NumPy array if the index is within bounds.
@@ -479,7 +464,7 @@ where
479
464
where
480
465
I : NpyIndex < Dim = D > ,
481
466
{
482
- unsafe { self . 0 . get ( index) }
467
+ unsafe { self . array . get ( index) }
483
468
}
484
469
}
485
470
@@ -489,7 +474,15 @@ where
489
474
D : Dimension ,
490
475
{
491
476
fn clone ( & self ) -> Self {
492
- Self :: try_new ( self . 0 ) . unwrap ( )
477
+ BORROW_FLAGS
478
+ . acquire ( self . array . py ( ) , self . address , self . key )
479
+ . unwrap ( ) ;
480
+
481
+ Self {
482
+ array : self . array ,
483
+ address : self . address ,
484
+ key : self . key ,
485
+ }
493
486
}
494
487
}
495
488
@@ -499,7 +492,7 @@ where
499
492
D : Dimension ,
500
493
{
501
494
fn drop ( & mut self ) {
502
- BORROW_FLAGS . release ( self . 0 ) ;
495
+ BORROW_FLAGS . release ( self . array . py ( ) , self . address , self . key ) ;
503
496
}
504
497
}
505
498
@@ -525,10 +518,16 @@ where
525
518
/// i.e. that only a single exclusive reference into the interior of the array can be created safely.
526
519
///
527
520
/// See the [module-level documentation](self) for more.
528
- pub struct PyReadwriteArray < ' py , T , D > ( & ' py PyArray < T , D > )
521
+ #[ repr( C ) ]
522
+ pub struct PyReadwriteArray < ' py , T , D >
529
523
where
530
524
T : Element ,
531
- D : Dimension ;
525
+ D : Dimension ,
526
+ {
527
+ array : & ' py PyArray < T , D > ,
528
+ address : usize ,
529
+ key : BorrowKey ,
530
+ }
532
531
533
532
/// Read-write borrow of a one-dimensional array.
534
533
pub type PyReadwriteArray1 < ' py , T > = PyReadwriteArray < ' py , T , Ix1 > ;
@@ -581,23 +580,30 @@ where
581
580
return Err ( BorrowError :: NotWriteable ) ;
582
581
}
583
582
584
- BORROW_FLAGS . acquire_mut ( array) ?;
583
+ let address = base_address ( array) ;
584
+ let key = BorrowKey :: from_array ( array) ;
585
585
586
- Ok ( Self ( array) )
586
+ BORROW_FLAGS . acquire_mut ( array. py ( ) , address, key) ?;
587
+
588
+ Ok ( Self {
589
+ array,
590
+ address,
591
+ key,
592
+ } )
587
593
}
588
594
589
595
/// Provides a mutable array view of the interior of the NumPy array.
590
596
#[ inline( always) ]
591
597
pub fn as_array_mut ( & mut self ) -> ArrayViewMut < T , D > {
592
598
// SAFETY: Global borrow flags ensure aliasing discipline.
593
- unsafe { self . 0 . as_array_mut ( ) }
599
+ unsafe { self . array . as_array_mut ( ) }
594
600
}
595
601
596
602
/// Provide a mutable slice view of the interior of the NumPy array if it is contiguous.
597
603
#[ inline( always) ]
598
604
pub fn as_slice_mut ( & mut self ) -> Result < & mut [ T ] , NotContiguousError > {
599
605
// SAFETY: Global borrow flags ensure aliasing discipline.
600
- unsafe { self . 0 . as_slice_mut ( ) }
606
+ unsafe { self . array . as_slice_mut ( ) }
601
607
}
602
608
603
609
/// Provide a mutable reference to an element of the NumPy array if the index is within bounds.
@@ -606,7 +612,7 @@ where
606
612
where
607
613
I : NpyIndex < Dim = D > ,
608
614
{
609
- unsafe { self . 0 . get_mut ( index) }
615
+ unsafe { self . array . get_mut ( index) }
610
616
}
611
617
}
612
618
@@ -632,16 +638,16 @@ where
632
638
/// });
633
639
/// ```
634
640
pub fn resize ( self , new_elems : usize ) -> PyResult < Self > {
635
- BORROW_FLAGS . release_mut ( self . 0 ) ;
641
+ let array = self . array ;
636
642
637
643
// SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
638
644
unsafe {
639
- self . 0 . resize ( new_elems) ?;
645
+ array . resize ( new_elems) ?;
640
646
}
641
647
642
- BORROW_FLAGS . acquire_mut ( self . 0 ) ? ;
648
+ drop ( self ) ;
643
649
644
- Ok ( self )
650
+ Ok ( Self :: try_new ( array ) . unwrap ( ) )
645
651
}
646
652
}
647
653
@@ -651,7 +657,7 @@ where
651
657
D : Dimension ,
652
658
{
653
659
fn drop ( & mut self ) {
654
- BORROW_FLAGS . release_mut ( self . 0 ) ;
660
+ BORROW_FLAGS . release_mut ( self . array . py ( ) , self . address , self . key ) ;
655
661
}
656
662
}
657
663
@@ -726,6 +732,10 @@ where
726
732
)
727
733
}
728
734
735
+ fn gcd_strides ( strides : & [ isize ] ) -> isize {
736
+ reduce ( strides. iter ( ) . copied ( ) , gcd) . unwrap_or ( 1 )
737
+ }
738
+
729
739
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
730
740
fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
731
741
if lhs >= rhs {
@@ -1276,4 +1286,43 @@ mod tests {
1276
1286
}
1277
1287
} ) ;
1278
1288
}
1289
+
1290
+ #[ test]
1291
+ #[ should_panic( expected = "AlreadyBorrowed" ) ]
1292
+ fn cannot_clone_exclusive_borrow_via_deref ( ) {
1293
+ Python :: with_gil ( |py| {
1294
+ let array = PyArray :: < f64 , _ > :: zeros ( py, ( 3 , 2 , 1 ) , false ) ;
1295
+
1296
+ let exclusive = array. readwrite ( ) ;
1297
+ let _shared = exclusive. clone ( ) ;
1298
+ } ) ;
1299
+ }
1300
+
1301
+ #[ test]
1302
+ fn failed_resize_does_not_double_release ( ) {
1303
+ Python :: with_gil ( |py| {
1304
+ let array = PyArray :: < f64 , _ > :: zeros ( py, 10 , false ) ;
1305
+
1306
+ // The view will make the internal reference check of `PyArray_Resize` fail.
1307
+ let locals = [ ( "array" , array) ] . into_py_dict ( py) ;
1308
+ let _view = py
1309
+ . eval ( "array[:]" , None , Some ( locals) )
1310
+ . unwrap ( )
1311
+ . downcast :: < PyArray1 < f64 > > ( )
1312
+ . unwrap ( ) ;
1313
+
1314
+ let exclusive = array. readwrite ( ) ;
1315
+ assert ! ( exclusive. resize( 100 ) . is_err( ) ) ;
1316
+ } ) ;
1317
+ }
1318
+
1319
+ #[ test]
1320
+ fn ineffective_resize_does_not_conflict ( ) {
1321
+ Python :: with_gil ( |py| {
1322
+ let array = PyArray :: < f64 , _ > :: zeros ( py, 10 , false ) ;
1323
+
1324
+ let exclusive = array. readwrite ( ) ;
1325
+ assert ! ( exclusive. resize( 10 ) . is_ok( ) ) ;
1326
+ } ) ;
1327
+ }
1279
1328
}
0 commit comments