@@ -16,6 +16,7 @@ def __init__(
         self,
         A: NDArray,
         N: int,
+        saveAt: bool = False,
         base_comm: MPI.Comm = MPI.COMM_WORLD,
         dtype: DTypeLike = "float64",
     ) -> None:
@@ -25,113 +26,91 @@ def __init__(
         # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
         self._P_prime = int(math.ceil(math.sqrt(size)))
         self._C = int(math.ceil(size / self._P_prime))
-        if self._P_prime * self._C < size:
+        if self._P_prime * self._C != size:
             raise Exception("Number of Procs must be a square number")
 
         # Compute this process's group and layer indices
         self._group_id = rank % self._P_prime
         self._layer_id = rank // self._P_prime
 
         # Split communicators by layer (rows) and by group (columns)
-        self.base_comm = base_comm
+        self.base_comm = base_comm
         self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
         self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
         self.A = A.astype(np.dtype(dtype))
+        if saveAt: self.At = A.T.conj()
 
         self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
         self.K = A.shape[1]
         self.N = N
 
         # Determine how many columns each group holds
         block_cols = int(math.ceil(self.N / self._P_prime))
-        local_col_start = self._group_id * block_cols
-        local_col_end = min(self.N, local_col_start + block_cols)
-        local_ncols = local_col_end - local_col_start
+        blk_rows = int(math.ceil(self.M / self._P_prime))
 
-        # Sum up the total number of input columns across all processes
-        total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM)
-        self.dims = (self.K, total_ncols)
+        self._row_start = self._group_id * blk_rows
+        self._row_end = min(self.M, self._row_start + blk_rows)
+
+        self._col_start = self._layer_id * block_cols
+        self._col_end = min(self.N, self._col_start + block_cols)
 
-        # Recompute how many output columns each layer holds
-        layer_col_start = self._layer_id * block_cols
-        layer_col_end = min(self.N, layer_col_start + block_cols)
-        layer_ncols = layer_col_end - layer_col_start
-        total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM)
+        self._local_ncols = self._col_end - self._col_start
+        self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
+        total_ncols = np.sum(self._rank_col_lens)
 
-        self.dimsd = (self.M, total_layer_cols)
+        self.dims = (self.K, total_ncols)
+        self.dimsd = (self.M, total_ncols)
         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
-
+
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
-        blk_cols = int(math.ceil(self.N / self._P_prime))
-        col_start = self._layer_id * blk_cols
-        col_end = min(self.N, col_start + blk_cols)
-        my_own_cols = max(0, col_end - col_start)
-        x = x.local_array.reshape((self.dims[0], my_own_cols))
-        x = x.astype(self.dtype)
-
-        B_block = self._layer_comm.bcast(x if self._group_id == self._layer_id else None, root=self._layer_id)
-        C_local = ncp.vstack(
+
+        my_own_cols = self._rank_col_lens[self.rank]
+        x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
+        x_arr = x_arr.astype(self.dtype)
+
+        X_local = self._layer_comm.bcast(x_arr if self._group_id == self._layer_id else None, root=self._layer_id)
+        Y_local = ncp.vstack(
             self._layer_comm.allgather(
-                ncp.matmul(self.A, B_block)
+                ncp.matmul(self.A, X_local)
             )
         )
 
-        layer_col_start = self._layer_id * blk_cols
-        layer_col_end = min(self.N, layer_col_start + blk_cols)
-        layer_ncols = max(0, layer_col_end - layer_col_start)
-        layer_col_lens = self.base_comm.allgather(layer_ncols)
-        mask = [i // self._P_prime for i in range(self.size)]
-
-        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
-                             local_shapes=[(self.M * c) for c in layer_col_lens],
-                             mask=mask,
+        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
+                             local_shapes=[(self.M * c) for c in self._rank_col_lens],
+                             mask=x.mask,
                              partition=Partition.SCATTER,
                              dtype=self.dtype)
-        y[:] = C_local.flatten()
+        y[:] = Y_local.flatten()
         return y
 
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
 
-        # Determine local column block for this layer
-        blk_cols = int(math.ceil(self.N / self._P_prime))
-        layer_col_start = self._layer_id * blk_cols
-        layer_col_end = min(self.N, layer_col_start + blk_cols)
-        layer_ncols = layer_col_end - layer_col_start
-        layer_col_lens = self.base_comm.allgather(layer_ncols)
-        x = x.local_array.reshape((self.M, layer_ncols)).astype(self.dtype)
-
-        # Determine local row block for this process group
-        blk_rows = int(math.ceil(self.M / self._P_prime))
-        row_start = self._group_id * blk_rows
-        row_end = min(self.M, row_start + blk_rows)
-
-        B_tile = x[row_start:row_end, :].astype(self.dtype)
-        A_local = self.A.T.conj().astype(self.dtype)
-
-        m, b = A_local.shape
-        pad = (-m) % self._P_prime
-        r = (m + pad) // self._P_prime
-        A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0))
+        x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype)
+        X_tile = x_arr[self._row_start:self._row_end, :]
+
+        A_local = self.At if hasattr(self, "At") else self.A.T.conj()
+        m, b = A_local.shape
+        pad = (-m) % self._P_prime
+        r = (m + pad) // self._P_prime
+        A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0))
         A_batch = A_pad.reshape(self._P_prime, r, b)
 
-        # Perform local matmul and unpad
-        Y_batch = ncp.matmul(A_batch, B_tile).astype(self.dtype)
-        Y_pad = Y_batch.reshape(r * self._P_prime, -1)
+        Y_batch = ncp.matmul(A_batch, X_tile)
+        Y_pad = Y_batch.reshape(r * self._P_prime, -1)
         y_local = Y_pad[:m, :]
         y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
 
-        mask = [i // self._P_prime for i in range(self.size)]
         y = DistributedArray(
             global_shape=(self.K * self.dimsd[1]),
-            local_shapes=[self.K * c for c in layer_col_lens],
-            mask=mask,
+            local_shapes=[self.K * c for c in self._rank_col_lens],
+            mask=x.mask,
             partition=Partition.SCATTER,
            dtype=self.dtype,
        )
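
The diff above organizes the `size` ranks on a square `P_prime × P_prime` grid: `group_id` selects a row block of `A` and `layer_id` selects a column block of the right-hand operand, so each rank owns exactly one tile of the product that `_matvec` later reconstructs collectively via the layer-wise broadcast and allgather. The snippet below is a single-process NumPy sketch of that index arithmetic only (no MPI; the loop over `rank` simulates the grid, and the sizes `size`, `M`, `K`, `N` are made-up illustrations, not values from this repository).

```python
import math
import numpy as np

size = 4                                   # pretend number of MPI ranks (hypothetical)
P_prime = int(math.ceil(math.sqrt(size)))  # grid side length
C = int(math.ceil(size / P_prime))
assert P_prime * C == size                 # same check as the constructor

M, K, N = 10, 6, 7                         # illustrative global sizes
A = np.random.rand(M, K)                   # full operator matrix
X = np.random.rand(K, N)                   # full right-hand operand

blk_rows = int(math.ceil(M / P_prime))     # row block owned by each group
blk_cols = int(math.ceil(N / P_prime))     # column block owned by each layer

Y = np.zeros((M, N))
for rank in range(size):
    group_id = rank % P_prime              # picks the row block of A
    layer_id = rank // P_prime             # picks the column block of X
    rs, re = group_id * blk_rows, min(M, (group_id + 1) * blk_rows)
    cs, ce = layer_id * blk_cols, min(N, (layer_id + 1) * blk_cols)
    # each simulated rank contributes exactly one tile of the product
    Y[rs:re, cs:ce] = A[rs:re, :] @ X[:, cs:ce]

assert np.allclose(Y, A @ X)               # the tiles assemble the full A @ X
```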
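
`_rmatvec` avoids an explicit Python loop over row chunks by zero-padding the adjoint block to a multiple of `P_prime` rows and reshaping it into a 3-D batch, so a single batched `matmul` produces all chunks before the unpad and the layer-wise allreduce. A minimal NumPy-only sketch of that pad/reshape/unpad step (shapes here are illustrative assumptions, not taken from the code above):

```python
import numpy as np

P_prime = 3                        # grid side length, as in the operator
m, b, n = 7, 4, 5                  # illustrative local shapes
A_local = np.random.rand(m, b)     # stands in for the local A.T.conj() block
X_tile = np.random.rand(b, n)      # stands in for the local tile of x

pad = (-m) % P_prime               # rows needed to reach a multiple of P_prime
r = (m + pad) // P_prime           # rows handled by each batch entry
A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode="constant", constant_values=0.0)
A_batch = A_pad.reshape(P_prime, r, b)                 # P_prime stacked row chunks
Y_pad = np.matmul(A_batch, X_tile).reshape(r * P_prime, n)
assert np.allclose(Y_pad[:m, :], A_local @ X_tile)     # unpadding recovers the plain product
```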