Skip to content

Commit 7992ed9

Browse files
committed
Broadcasting over BLAS1
1 parent 4ad1e49 commit 7992ed9

File tree

10 files changed

+74
-108
lines changed

10 files changed

+74
-108
lines changed

src/bicgstabl.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ function bicgstabl_iterator!(x, A, b, l::Int = 2;
4747
copy!(residual, b)
4848
else
4949
A_mul_B!(residual, A, x)
50-
@blas! residual -= one(T) * b
51-
@blas! residual *= -one(T)
50+
residual .= b .- residual
5251
mv_products += 1
5352
end
5453

@@ -89,10 +88,7 @@ function next(it::BiCGStabIterable, iteration::Int)
8988
β = ρ / it.σ
9089

9190
# us[:, 1 : j] .= rs[:, 1 : j] - β * us[:, 1 : j]
92-
for i = 1 : j
93-
@blas! view(it.us, :, i) *= -β
94-
@blas! view(it.us, :, i) += one(T) * view(it.rs, :, i)
95-
end
91+
it.us[:, 1 : j] .= it.rs[:, 1 : j] .- β .* it.us[:, 1 : j]
9692

9793
# us[:, j + 1] = Pl \ (A * us[:, j])
9894
next_u = view(it.us, :, j + 1)
@@ -102,18 +98,15 @@ function next(it::BiCGStabIterable, iteration::Int)
10298
it.σ = dot(it.r_shadow, next_u)
10399
α = ρ / it.σ
104100

105-
# rs[:, 1 : j] .= rs[:, 1 : j] - α * us[:, 2 : j + 1]
106-
for i = 1 : j
107-
@blas! view(it.rs, :, i) -= α * view(it.us, :, i + 1)
108-
end
101+
it.rs[:, 1 : j] .-= α .* it.us[:, 2 : j + 1]
109102

110103
# rs[:, j + 1] = Pl \ (A * rs[:, j])
111104
next_r = view(it.rs, :, j + 1)
112105
A_mul_B!(next_r, it.A , view(it.rs, :, j))
113106
A_ldiv_B!(it.Pl, next_r)
114107

115108
# x = x + α * us[:, 1]
116-
@blas! it.x += α * view(it.us, :, 1)
109+
it.x .+= α .* view(it.us, :, 1)
117110
end
118111

119112
# Bookkeeping

src/cg.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,15 @@ end
4343
function next(it::CGIterable, iteration::Int)
4444
# u := r + βu (almost an axpy)
4545
β = it.residual^2 / it.prev_residual^2
46-
@blas! it.u *= β
47-
@blas! it.u += one(eltype(it.u)) * it.r
46+
it.u .= it.r .+ β .* it.u
4847

4948
# c = A * u
5049
A_mul_B!(it.c, it.A, it.u)
5150
α = it.residual^2 / dot(it.u, it.c)
5251

5352
# Improve solution and residual
54-
@blas! it.x += α * it.u
55-
@blas! it.r -= α * it.c
53+
it.x .+= α .* it.u
54+
it.r .-= α .* it.c
5655

5756
it.prev_residual = it.residual
5857
it.residual = norm(it.r)
@@ -73,16 +72,15 @@ function next(it::PCGIterable, iteration::Int)
7372

7473
# u := c + βu (almost an axpy)
7574
β = it.ρ / ρ_prev
76-
@blas! it.u *= β
77-
@blas! it.u += one(eltype(it.u)) * it.c
75+
it.u .= it.c .+ β .* it.u
7876

7977
# c = A * u
8078
A_mul_B!(it.c, it.A, it.u)
8179
α = it.ρ / dot(it.u, it.c)
8280

8381
# Improve solution and residual
84-
@blas! it.x += α * it.u
85-
@blas! it.r -= α * it.c
82+
it.x .+= α .* it.u
83+
it.r .-= α .* it.c
8684

8785
it.residual = norm(it.r)
8886

@@ -110,7 +108,7 @@ function cg_iterator!(x, A, b, Pl = Identity();
110108
else
111109
mv_products = 1
112110
c = A * x
113-
@blas! r -= one(eltype(x)) * c
111+
r .-= c
114112
residual = norm(r)
115113
reltol = norm(b) * tol
116114
end

src/chebyshev.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,14 @@ function next(cheb::ChebyshevIterable, iteration::Int)
3737
else
3838
β = (cheb.λ_diff * cheb.α / 2) ^ 2
3939
cheb.α = inv(cheb.λ_avg - β)
40-
41-
# Almost an axpy u = c + βu
42-
scale!(cheb.u, β)
43-
@blas! cheb.u += one(T) * cheb.c
40+
cheb.u .= cheb.c .+ β .* cheb.c
4441
end
4542

4643
A_mul_B!(cheb.c, cheb.A, cheb.u)
4744
cheb.mv_products += 1
4845

49-
@blas! cheb.x += cheb.α * cheb.u
50-
@blas! cheb.r -= cheb.α * cheb.c
46+
cheb.x .+= cheb.α .* cheb.u
47+
cheb.r .-= cheb.α .* cheb.c
5148

5249
cheb.resnorm = norm(cheb.r)
5350

@@ -73,7 +70,7 @@ function chebyshev_iterable!(x, A, b, λmin::Real, λmax::Real;
7370
mv_products = 0
7471
else
7572
A_mul_B!(c, A, x)
76-
@blas! r -= one(T) * c
73+
r .-= c
7774
resnorm = norm(r)
7875
reltol = tol * norm(b)
7976
mv_products = 1

src/gmres.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,14 @@ function init!(arnoldi::ArnoldiDecomp{T}, x, b, Pl, Ax; initially_zero::Bool = f
221221
# Potentially save one MV product
222222
if !initially_zero
223223
A_mul_B!(Ax, arnoldi.A, x)
224-
@blas! first_col -= one(T) * Ax
224+
first_col .-= Ax
225225
end
226226

227227
A_ldiv_B!(Pl, first_col)
228228

229229
# Normalize
230230
β = norm(first_col)
231-
@blas! first_col *= one(T) / β
231+
first_col .*= inv(β)
232232
β
233233
end
234234

@@ -259,7 +259,7 @@ function update_solution!(x, y, arnoldi::ArnoldiDecomp{T}, Pr, k::Int, Ax) where
259259
# Computing x ← x + Pr \ (V * y) and use Ax as a work space
260260
A_mul_B!(Ax, view(arnoldi.V, :, 1 : k - 1), y)
261261
A_ldiv_B!(Pr, Ax)
262-
@blas! x += one(T) * Ax
262+
x .+= Ax
263263
end
264264

265265
function expand!(arnoldi::ArnoldiDecomp, Pl::Identity, Pr::Identity, k::Int, Ax)

src/idrs.jl

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -111,32 +111,26 @@ function idrs_method!(log::ConvergenceHistory, X, op, args, C::T,
111111
# Solve small system and make v orthogonal to P
112112

113113
c = LowerTriangular(M[k:s,k:s])\f[k:s]
114-
@blas! V = G[k]
115-
@blas! V *= c[1]
114+
V .= c[1] .* G[k]
115+
Q .= c[1] .* U[k]
116116

117-
@blas! Q = U[k]
118-
@blas! Q *= c[1]
119117
for i = k+1:s
120-
@blas! V += c[i-k+1]*G[i]
121-
@blas! Q += c[i-k+1]*U[i]
118+
V .+= c[i-k+1] .* G[i]
119+
Q .+= c[i-k+1] .* U[i]
122120
end
123121

124122
# Compute new U[:,k] and G[:,k], G[:,k] is in space G_j
123+
V .= R .- V
125124

126-
#V = R - V
127-
@blas! V *= -1.
128-
@blas! V += R
129-
130-
@blas! U[k] = Q
131-
@blas! U[k] += om*V
125+
U[k] .= Q .+ om .* V
132126
G[k] = op(U[k], args...)
133127

134128
# Bi-orthogonalise the new basis vectors
135129

136130
for i in 1:k-1
137131
alpha = vecdot(P[i],G[k])/M[i,i]
138-
@blas! G[k] += -alpha*G[i]
139-
@blas! U[k] += -alpha*U[i]
132+
G[k] .-= alpha .* G[i]
133+
U[k] .-= alpha .* U[i]
140134
end
141135

142136
# New column of M = P'*G (first k-1 entries are zero)
@@ -148,22 +142,17 @@ function idrs_method!(log::ConvergenceHistory, X, op, args, C::T,
148142
# Make r orthogonal to q_i, i = 1..k
149143

150144
beta = f[k]/M[k,k]
151-
@blas! R += -beta*G[k]
152-
@blas! X += beta*U[k]
145+
R .-= beta .* G[k]
146+
X .+= beta .* U[k]
153147

154148
normR = vecnorm(R)
155149
if smoothing
156-
# T_s = R_s - R
157-
@blas! T_s = R_s
158-
@blas! T_s += (-1.)*R
150+
T_s .= R_s .- R
159151

160152
gamma = vecdot(R_s, T_s)/vecdot(T_s, T_s)
161153

162-
@blas! R_s += -gamma*T_s
163-
# X_s = X_s - gamma*(X_s - X)
164-
@blas! T_s = X_s
165-
@blas! T_s += (-1.)*X
166-
@blas! X_s += -gamma*T_s
154+
R_s .-= gamma .* T_s
155+
X_s .-= gamma .* (X_s .- X)
167156

168157
normR = vecnorm(R_s)
169158
end
@@ -182,34 +171,29 @@ function idrs_method!(log::ConvergenceHistory, X, op, args, C::T,
182171

183172
# Now we have sufficient vectors in G_j to compute residual in G_j+1
184173
# Note: r is already perpendicular to P so v = r
185-
@blas! V = R
174+
copy!(V, R)
186175
Q = op(V, args...)::T
187176
om = omega(Q, R)
188-
@blas! R += -om*Q
189-
@blas! X += om*V
177+
R .-= om .* Q
178+
X .+= om .* V
190179

191180
normR = vecnorm(R)
192181
if smoothing
193-
# T_s = R_s - R
194-
@blas! T_s = R_s
195-
@blas! T_s += (-1.)*R
182+
T_s .= R_s .- R
196183

197184
gamma = vecdot(R_s, T_s)/vecdot(T_s, T_s)
198185

199-
@blas! R_s += -gamma*T_s
200-
# X_s = X_s - gamma*(X_s - X)
201-
@blas! T_s = X_s
202-
@blas! T_s += (-1.)*X
203-
@blas! X_s += -gamma*T_s
186+
R_s .-= gamma .* T_s
187+
X_s .-= gamma .* (X_s .- X)
204188

205189
normR = vecnorm(R_s)
206190
end
207191
iter += 1
208-
nextiter!(log,mvps=1)
192+
nextiter!(log, mvps=1)
209193
push!(log, :resnorm, normR)
210194
end
211195
if smoothing
212-
@blas! X = X_s
196+
copy!(X, X_s)
213197
end
214198
verbose && @printf("\n")
215199
setconv(log, 0<=normR<tol)

src/lsmr.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ function lsmr_method!(log::ConvergenceHistory, x, A, b, v, h, hbar;
109109
# form the first vectors u and v (satisfy β*u = b, α*v = A'u)
110110
u = A_mul_B!(-1, A, x, 1, b)
111111
β = norm(u)
112-
β > 0 && @blas! u *= inv(β)
112+
u .*= inv(β)
113113
Ac_mul_B!(1, A, u, 0, v)
114114
α = norm(v)
115-
α > 0 && @blas! v *= inv(α)
115+
v .*= inv(α)
116116

117117
log[:atol] = atol
118118
log[:btol] = btol
@@ -126,7 +126,7 @@ function lsmr_method!(log::ConvergenceHistory, x, A, b, v, h, hbar;
126126
cbar = one(Tr)
127127
sbar = zero(Tr)
128128

129-
@blas! h = v
129+
copy!(h, v)
130130
fill!(hbar, zero(Tr))
131131

132132
# Initialize variables for estimation of ||r||.
@@ -162,10 +162,10 @@ function lsmr_method!(log::ConvergenceHistory, x, A, b, v, h, hbar;
162162
β = norm(u)
163163
if β > 0
164164
log.mtvps+=1
165-
@blas! u *= inv(β)
165+
u .*= inv(β)
166166
Ac_mul_B!(1, A, u, -β, v)
167167
α = norm(v)
168-
α > 0 && @blas! v *= inv(α)
168+
v .*= inv(α)
169169
end
170170

171171
# Construct rotation Qhat_{k,2k+1}.
@@ -193,11 +193,9 @@ function lsmr_method!(log::ConvergenceHistory, x, A, b, v, h, hbar;
193193
ζbar = - sbar * ζbar
194194

195195
# Update h, h_hat, x.
196-
@blas! hbar *= - θbar * ρ / (ρold * ρbarold)
197-
@blas! hbar += h
198-
@blas! x +=/* ρbar))*hbar
199-
@blas! h *= - θnew / ρ
200-
@blas! h += v
196+
hbar .= hbar .* (-θbar * ρ / (ρold * ρbarold)) .+ h
197+
x .+=/* ρbar)) * hbar
198+
h .= h .* (-θnew / ρ) .+ v
201199

202200
##############################################################################
203201
##

src/lsqr.jl

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,12 @@ function lsqr_method!(log::ConvergenceHistory, x, A, b;
125125
alpha = zero(Tr)
126126
if beta > 0
127127
log.mtvps=1
128-
@blas! u *= inv(beta)
128+
u .*= inv(beta)
129129
Ac_mul_B!(v,A,u)
130130
alpha = norm(v)
131131
end
132132
if alpha > 0
133-
@blas! v *= inv(alpha)
133+
v .*= inv(alpha)
134134
end
135135
w = copy(v)
136136
wrho = similar(w)
@@ -158,21 +158,19 @@ function lsqr_method!(log::ConvergenceHistory, x, A, b;
158158
# Note that the following three lines are a band aid for a GEMM: X: C := αAB + βC.
159159
# This is already supported in A_mul_B! for sparse and distributed matrices, but not yet dense
160160
A_mul_B!(tmpm, A, v)
161-
@blas! u *= -alpha
162-
@blas! u += one(eltype(tmpm))*tmpm
161+
u .= -alpha .* u .+ tmpm
163162
beta = norm(u)
164163
if beta > 0
165164
log.mtvps+=1
166-
@blas! u *= inv(beta)
165+
u .*= inv(beta)
167166
Anorm = sqrt(abs2(Anorm) + abs2(alpha) + abs2(beta) + dampsq)
168167
# Note that the following three lines are a band aid for a GEMM: X: C := αA'B + βC.
169168
# This is already supported in Ac_mul_B! for sparse and distributed matrices, but not yet dense
170169
Ac_mul_B!(tmpn, A, u)
171-
@blas! v *= -beta
172-
@blas! v += one(eltype(tmpn))*tmpn
170+
v .= -beta .* v .+ tmpn
173171
alpha = norm(v)
174172
if alpha > 0
175-
@blas! v *= inv(alpha)
173+
v .*= inv(alpha)
176174
end
177175
end
178176

@@ -199,11 +197,9 @@ function lsqr_method!(log::ConvergenceHistory, x, A, b;
199197
t1 = phi /rho
200198
t2 = - theta/rho
201199

202-
@blas! x += t1*w
203-
@blas! w *= t2
204-
@blas! w += one(t2)*v
205-
@blas! wrho = w
206-
@blas! wrho *= inv(rho)
200+
x .+= t1*w
201+
w = t2 .* w .+ v
202+
wrho .= w .* inv(rho)
207203
ddnorm += norm(wrho)
208204

209205
# Use a plane rotation on the right to eliminate the

0 commit comments

Comments
 (0)