@@ -580,13 +580,13 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
580580 pad_m = bm - local_m
581581
582582 if pad_k > 0 or pad_m > 0 :
583- x_block = np .pad (x_block , [(0 , pad_k ), (0 , pad_m )], mode = 'constant' )
583+ x_block = ncp .pad (x_block , [(0 , pad_k ), (0 , pad_m )], mode = 'constant' )
584584
585- Y_local = np .zeros ((self .A .shape [0 ], bm ),dtype = output_dtype )
585+ Y_local = ncp .zeros ((self .A .shape [0 ], bm ),dtype = output_dtype )
586586
587587 for k in range (self ._P_prime ):
588- Atemp = self .A .copy () if self ._col_id == k else np .empty_like (self .A )
589- Xtemp = x_block .copy () if self ._row_id == k else np .empty_like (x_block )
588+ Atemp = self .A .copy () if self ._col_id == k else ncp .empty_like (self .A )
589+ Xtemp = x_block .copy () if self ._row_id == k else ncp .empty_like (x_block )
590590 self ._row_comm .Bcast (Atemp , root = k )
591591 self ._col_comm .Bcast (Xtemp , root = k )
592592 Y_local += ncp .dot (Atemp , Xtemp )
@@ -646,14 +646,14 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
646646 pad_m = bm - local_m
647647
648648 if pad_n > 0 or pad_m > 0 :
649- x_block = np .pad (x_block , [(0 , pad_n ), (0 , pad_m )], mode = 'constant' )
649+ x_block = ncp .pad (x_block , [(0 , pad_n ), (0 , pad_m )], mode = 'constant' )
650650
651651 A_local = self .At if hasattr (self , "At" ) else self .A .T .conj ()
652- Y_local = np .zeros ((self .A .shape [1 ], bm ), dtype = output_dtype )
652+ Y_local = ncp .zeros ((self .A .shape [1 ], bm ), dtype = output_dtype )
653653
654654 for k in range (self ._P_prime ):
655655 requests = []
656- ATtemp = np .empty_like (A_local )
656+ ATtemp = ncp .empty_like (A_local )
657657 srcA = k * self ._P_prime + self ._row_id
658658 tagA = (100 + k ) * 1000 + self .rank
659659 requests .append (self .base_comm .Irecv (ATtemp , source = srcA , tag = tagA ))
@@ -663,7 +663,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
663663 destA = fixed_col * self ._P_prime + moving_col
664664 tagA = (100 + k ) * 1000 + destA
665665 requests .append (self .base_comm .Isend (A_local , dest = destA , tag = tagA ))
666- Xtemp = x_block .copy () if self ._row_id == k else np .empty_like (x_block )
666+ Xtemp = x_block .copy () if self ._row_id == k else ncp .empty_like (x_block )
667667 requests .append (self ._col_comm .Ibcast (Xtemp , root = k ))
668668 MPI .Request .Waitall (requests )
669669 Y_local += ncp .dot (ATtemp , Xtemp )
0 commit comments