@@ -146,24 +146,27 @@ def cg_solve(A, b):
146
146
e_p = dpctl .SyclEvent ()
147
147
e_x = dpctl .SyclEvent ()
148
148
for i in range (max_iters ):
149
- he_dot , e_dot = sycl_gemm .gemv (
150
- exec_queue , A , p , Ap , depends = [e_p ]
151
- ) # Ap = A @ p
149
+ # Ap = A @ p
150
+ he_dot , e_dot = sycl_gemm .gemv (exec_queue , A , p , Ap , depends = [e_p ])
152
151
all_host_tasks .append (he_dot )
153
- alpha = rsold / sycl_gemm .dot_blocking ( # alpha = rsold / dot(p, Ap)
152
+ # alpha = rsold / dot(p, Ap)
153
+ alpha = rsold / sycl_gemm .dot_blocking (
154
154
exec_queue , p , Ap , depends = [e_dot ]
155
155
)
156
+ # x = x + alpha * p
156
157
he1_axpby , e1_axpby = sycl_gemm .axpby_inplace (
157
158
exec_queue , alpha , p , 1 , x , depends = [e_p , e_x ]
158
- ) # x = x + alpha * p
159
+ )
159
160
all_host_tasks .append (he1_axpby )
160
161
e_x = e1_axpby
161
162
163
+ # r = r - alpha * Ap
162
164
he2_axpby , e2_axpby = sycl_gemm .axpby_inplace (
163
165
exec_queue , - alpha , Ap , 1 , r , depends = [e_p ]
164
- ) # r = r - alpha * Ap
166
+ )
165
167
all_host_tasks .append (he2_axpby )
166
168
169
+ # rsnew = dot(r, r)
167
170
rsnew = sycl_gemm .norm_squared_blocking (
168
171
exec_queue , r , depends = [e2_axpby ]
169
172
)
@@ -173,9 +176,10 @@ def cg_solve(A, b):
173
176
break
174
177
beta = rsnew / rsold
175
178
179
+ # p = r + beta * p
176
180
he3_axpby , e3_axpby = sycl_gemm .axpby_inplace (
177
181
exec_queue , 1 , r , beta , p , depends = [e1_axpby , e2_axpby ]
178
- ) # p = r + beta * p
182
+ )
179
183
180
184
rsold = rsnew
181
185
all_host_tasks .append (he3_axpby )
0 commit comments