@@ -180,7 +180,7 @@ use crate::dtype::Element;
180
180
use crate :: error:: { BorrowError , NotContiguousError } ;
181
181
use crate :: npyffi:: { self , PyArrayObject , NPY_ARRAY_WRITEABLE } ;
182
182
183
- #[ derive( PartialEq , Eq , Hash ) ]
183
+ #[ derive( Clone , Copy , PartialEq , Eq , Hash ) ]
184
184
struct BorrowKey {
185
185
/// exclusive range of lowest and highest address covered by array
186
186
range : ( usize , usize ) ,
@@ -375,10 +375,16 @@ static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
375
375
/// i.e. that only shared references into the interior of the array can be created safely.
376
376
///
377
377
/// See the [module-level documentation](self) for more.
378
- pub struct PyReadonlyArray < ' py , T , D > ( & ' py PyArray < T , D > )
378
+ #[ repr( C ) ]
379
+ pub struct PyReadonlyArray < ' py , T , D >
379
380
where
380
381
T : Element ,
381
- D : Dimension ;
382
+ D : Dimension ,
383
+ {
384
+ array : & ' py PyArray < T , D > ,
385
+ address : usize ,
386
+ key : BorrowKey ,
387
+ }
382
388
383
389
/// Read-only borrow of a one-dimensional array.
384
390
pub type PyReadonlyArray1 < ' py , T > = PyReadonlyArray < ' py , T , Ix1 > ;
@@ -409,7 +415,7 @@ where
409
415
type Target = PyArray < T , D > ;
410
416
411
417
fn deref ( & self ) -> & Self :: Target {
412
- self . 0
418
+ self . array
413
419
}
414
420
}
415
421
@@ -426,27 +432,30 @@ where
426
432
D : Dimension ,
427
433
{
428
434
pub ( crate ) fn try_new ( array : & ' py PyArray < T , D > ) -> Result < Self , BorrowError > {
429
- let py = array. py ( ) ;
430
435
let address = base_address ( array) ;
431
436
let key = BorrowKey :: from_array ( array) ;
432
437
433
- BORROW_FLAGS . acquire ( py , address, key) ?;
438
+ BORROW_FLAGS . acquire ( array . py ( ) , address, key) ?;
434
439
435
- Ok ( Self ( array) )
440
+ Ok ( Self {
441
+ array,
442
+ address,
443
+ key,
444
+ } )
436
445
}
437
446
438
447
/// Provides an immutable array view of the interior of the NumPy array.
439
448
#[ inline( always) ]
440
449
pub fn as_array ( & self ) -> ArrayView < T , D > {
441
450
// SAFETY: Global borrow flags ensure aliasing discipline.
442
- unsafe { self . 0 . as_array ( ) }
451
+ unsafe { self . array . as_array ( ) }
443
452
}
444
453
445
454
/// Provide an immutable slice view of the interior of the NumPy array if it is contiguous.
446
455
#[ inline( always) ]
447
456
pub fn as_slice ( & self ) -> Result < & [ T ] , NotContiguousError > {
448
457
// SAFETY: Global borrow flags ensure aliasing discipline.
449
- unsafe { self . 0 . as_slice ( ) }
458
+ unsafe { self . array . as_slice ( ) }
450
459
}
451
460
452
461
/// Provide an immutable reference to an element of the NumPy array if the index is within bounds.
@@ -455,7 +464,7 @@ where
455
464
where
456
465
I : NpyIndex < Dim = D > ,
457
466
{
458
- unsafe { self . 0 . get ( index) }
467
+ unsafe { self . array . get ( index) }
459
468
}
460
469
}
461
470
@@ -465,7 +474,15 @@ where
465
474
D : Dimension ,
466
475
{
467
476
fn clone ( & self ) -> Self {
468
- Self :: try_new ( self . 0 ) . unwrap ( )
477
+ BORROW_FLAGS
478
+ . acquire ( self . array . py ( ) , self . address , self . key )
479
+ . unwrap ( ) ;
480
+
481
+ Self {
482
+ array : self . array ,
483
+ address : self . address ,
484
+ key : self . key ,
485
+ }
469
486
}
470
487
}
471
488
@@ -475,11 +492,7 @@ where
475
492
D : Dimension ,
476
493
{
477
494
fn drop ( & mut self ) {
478
- let py = self . 0 . py ( ) ;
479
- let address = base_address ( self . 0 ) ;
480
- let key = BorrowKey :: from_array ( self . 0 ) ;
481
-
482
- BORROW_FLAGS . release ( py, address, key) ;
495
+ BORROW_FLAGS . release ( self . array . py ( ) , self . address , self . key ) ;
483
496
}
484
497
}
485
498
@@ -505,10 +518,16 @@ where
505
518
/// i.e. that only a single exclusive reference into the interior of the array can be created safely.
506
519
///
507
520
/// See the [module-level documentation](self) for more.
508
- pub struct PyReadwriteArray < ' py , T , D > ( & ' py PyArray < T , D > )
521
+ #[ repr( C ) ]
522
+ pub struct PyReadwriteArray < ' py , T , D >
509
523
where
510
524
T : Element ,
511
- D : Dimension ;
525
+ D : Dimension ,
526
+ {
527
+ array : & ' py PyArray < T , D > ,
528
+ address : usize ,
529
+ key : BorrowKey ,
530
+ }
512
531
513
532
/// Read-write borrow of a one-dimensional array.
514
533
pub type PyReadwriteArray1 < ' py , T > = PyReadwriteArray < ' py , T , Ix1 > ;
@@ -561,27 +580,30 @@ where
561
580
return Err ( BorrowError :: NotWriteable ) ;
562
581
}
563
582
564
- let py = array. py ( ) ;
565
583
let address = base_address ( array) ;
566
584
let key = BorrowKey :: from_array ( array) ;
567
585
568
- BORROW_FLAGS . acquire_mut ( py , address, key) ?;
586
+ BORROW_FLAGS . acquire_mut ( array . py ( ) , address, key) ?;
569
587
570
- Ok ( Self ( array) )
588
+ Ok ( Self {
589
+ array,
590
+ address,
591
+ key,
592
+ } )
571
593
}
572
594
573
595
/// Provides a mutable array view of the interior of the NumPy array.
574
596
#[ inline( always) ]
575
597
pub fn as_array_mut ( & mut self ) -> ArrayViewMut < T , D > {
576
598
// SAFETY: Global borrow flags ensure aliasing discipline.
577
- unsafe { self . 0 . as_array_mut ( ) }
599
+ unsafe { self . array . as_array_mut ( ) }
578
600
}
579
601
580
602
/// Provide a mutable slice view of the interior of the NumPy array if it is contiguous.
581
603
#[ inline( always) ]
582
604
pub fn as_slice_mut ( & mut self ) -> Result < & mut [ T ] , NotContiguousError > {
583
605
// SAFETY: Global borrow flags ensure aliasing discipline.
584
- unsafe { self . 0 . as_slice_mut ( ) }
606
+ unsafe { self . array . as_slice_mut ( ) }
585
607
}
586
608
587
609
/// Provide a mutable reference to an element of the NumPy array if the index is within bounds.
@@ -590,7 +612,7 @@ where
590
612
where
591
613
I : NpyIndex < Dim = D > ,
592
614
{
593
- unsafe { self . 0 . get_mut ( index) }
615
+ unsafe { self . array . get_mut ( index) }
594
616
}
595
617
}
596
618
@@ -616,23 +638,16 @@ where
616
638
/// });
617
639
/// ```
618
640
pub fn resize ( self , new_elems : usize ) -> PyResult < Self > {
619
- let py = self . 0 . py ( ) ;
620
- let address = base_address ( self . 0 ) ;
621
- let key = BorrowKey :: from_array ( self . 0 ) ;
622
-
623
- BORROW_FLAGS . release_mut ( py, address, key) ;
641
+ let array = self . array ;
624
642
625
643
// SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
626
644
unsafe {
627
- self . 0 . resize ( new_elems) ?;
645
+ array . resize ( new_elems) ?;
628
646
}
629
647
630
- let address = base_address ( self . 0 ) ;
631
- let key = BorrowKey :: from_array ( self . 0 ) ;
632
-
633
- BORROW_FLAGS . acquire_mut ( py, address, key) ?;
648
+ drop ( self ) ;
634
649
635
- Ok ( self )
650
+ Ok ( Self :: try_new ( array ) . unwrap ( ) )
636
651
}
637
652
}
638
653
@@ -642,11 +657,7 @@ where
642
657
D : Dimension ,
643
658
{
644
659
fn drop ( & mut self ) {
645
- let py = self . 0 . py ( ) ;
646
- let address = base_address ( self . 0 ) ;
647
- let key = BorrowKey :: from_array ( self . 0 ) ;
648
-
649
- BORROW_FLAGS . release_mut ( py, address, key) ;
660
+ BORROW_FLAGS . release_mut ( self . array . py ( ) , self . address , self . key ) ;
650
661
}
651
662
}
652
663
@@ -1275,4 +1286,43 @@ mod tests {
1275
1286
}
1276
1287
} ) ;
1277
1288
}
1289
+
1290
+ #[ test]
1291
+ #[ should_panic( expected = "AlreadyBorrowed" ) ]
1292
+ fn cannot_clone_exclusive_borrow_via_deref ( ) {
1293
+ Python :: with_gil ( |py| {
1294
+ let array = PyArray :: < f64 , _ > :: zeros ( py, ( 3 , 2 , 1 ) , false ) ;
1295
+
1296
+ let exclusive = array. readwrite ( ) ;
1297
+ let _shared = exclusive. clone ( ) ;
1298
+ } ) ;
1299
+ }
1300
+
1301
+ #[ test]
1302
+ fn failed_resize_does_not_double_release ( ) {
1303
+ Python :: with_gil ( |py| {
1304
+ let array = PyArray :: < f64 , _ > :: zeros ( py, 10 , false ) ;
1305
+
1306
+ // The view will make the internal reference check of `PyArray_Resize` fail.
1307
+ let locals = [ ( "array" , array) ] . into_py_dict ( py) ;
1308
+ let _view = py
1309
+ . eval ( "array[:]" , None , Some ( locals) )
1310
+ . unwrap ( )
1311
+ . downcast :: < PyArray1 < f64 > > ( )
1312
+ . unwrap ( ) ;
1313
+
1314
+ let exclusive = array. readwrite ( ) ;
1315
+ assert ! ( exclusive. resize( 100 ) . is_err( ) ) ;
1316
+ } ) ;
1317
+ }
1318
+
1319
+ #[ test]
1320
+ fn ineffective_resize_does_not_conflict ( ) {
1321
+ Python :: with_gil ( |py| {
1322
+ let array = PyArray :: < f64 , _ > :: zeros ( py, 10 , false ) ;
1323
+
1324
+ let exclusive = array. readwrite ( ) ;
1325
+ assert ! ( exclusive. resize( 10 ) . is_ok( ) ) ;
1326
+ } ) ;
1327
+ }
1278
1328
}
0 commit comments