Skip to content

Commit 5729d6a

Browse files
authored
Fix issue with complex adjoint views (#14)
* Fix issue with complex adjoint views * Introduce ConjPtr
1 parent 81f8ff8 commit 5729d6a

File tree

4 files changed

+36
-13
lines changed

4 files changed

+36
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.2.3"
4+
version = "0.2.4"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import Base: AbstractArray, AbstractMatrix, AbstractVector,
2727
similar, @_gc_preserve_end, @_gc_preserve_begin,
2828
@nexprs, @ncall, @ntuple, tuple_type_tail,
2929
all, any, isbitsunion, issubset, replace_in_print_matrix, replace_with_centered_mark,
30-
strides, unsafe_convert
30+
strides, unsafe_convert, first_index
3131

3232
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcasted,
3333
combine_eltypes, DefaultArrayStyle, instantiate, materialize,
@@ -71,8 +71,24 @@ _transpose_strides(a,b) = (b,a)
7171
strides(A::Adjoint) = _transpose_strides(strides(parent(A))...)
7272
strides(A::Transpose) = _transpose_strides(strides(parent(A))...)
7373

74+
"""
75+
ConjPtr{T}
76+
77+
represents that the entry is the complex-conjugate of the pointed to entry.
78+
"""
79+
struct ConjPtr{T}
80+
ptr::Ptr{T}
81+
end
82+
7483
unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real}) where T<:Real = unsafe_convert(Ptr{T}, parent(A))
7584
unsafe_convert(::Type{Ptr{T}}, A::Transpose) where T = unsafe_convert(Ptr{T}, parent(A))
85+
# work-around issue with complex conjugation of pointer
86+
unsafe_convert(::Type{Ptr{T}}, Ac::Adjoint{<:Complex}) where T<:Complex = unsafe_convert(ConjPtr{T}, parent(Ac))
87+
unsafe_convert(::Type{ConjPtr{T}}, Ac::Adjoint{<:Complex}) where T<:Complex = unsafe_convert(Ptr{T}, parent(Ac))
88+
function unsafe_convert(::Type{ConjPtr{T}}, V::SubArray{T,2}) where {T,N,P}
89+
kr, jr = parentindices(V)
90+
unsafe_convert(Ptr{T}, view(parent(V)', jr, kr))
91+
end
7692

7793
include("memorylayout.jl")
7894
include("muladd.jl")

src/muladd.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -257,64 +257,65 @@ end
257257
end
258258
end
259259

260+
# work around pointer issues
260261
@inline materialize!(M::BlasMatMulVecAdd{<:AbstractColumnMajor,<:AbstractStridedLayout,<:AbstractStridedLayout}) =
261262
_gemv!('N', M.α, M.A, M.B, M.β, M.C)
262263
@inline materialize!(M::BlasMatMulVecAdd{<:AbstractRowMajor,<:AbstractStridedLayout,<:AbstractStridedLayout}) =
263264
_gemv!('T', M.α, transpose(M.A), M.B, M.β, M.C)
264265
@inline materialize!(M::BlasMatMulVecAdd{<:ConjLayout{<:AbstractRowMajor},<:AbstractStridedLayout,<:AbstractStridedLayout,<:BlasComplex}) =
265-
_gemv!('C', M.α, M.A', M.B, M.β, M.C)
266+
_gemv!('C', M.α, adjoint(M.A), M.B, M.β, M.C)
266267

267268
@inline materialize!(M::BlasVecMulMatAdd{<:AbstractColumnMajor,<:AbstractColumnMajor,<:AbstractColumnMajor}) =
268269
_gemm!('N', 'N', M.α, M.A, M.B, M.β, M.C)
269270
@inline materialize!(M::BlasVecMulMatAdd{<:AbstractColumnMajor,<:AbstractRowMajor,<:AbstractColumnMajor}) =
270271
_gemm!('N', 'T', M.α, M.A, transpose(M.B), M.β, M.C)
271272
@inline materialize!(M::BlasVecMulMatAdd{<:AbstractColumnMajor,<:ConjLayout{<:AbstractRowMajor},<:AbstractColumnMajor,<:BlasComplex}) =
272-
_gemm!('N', 'C', M.α, M.A, M.B', M.β, M.C)
273+
_gemm!('N', 'C', M.α, M.A, adjoint(M.B), M.β, M.C)
273274

274275
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractColumnMajor,<:AbstractColumnMajor,<:AbstractColumnMajor}) =
275276
_gemm!('N', 'N', M.α, M.A, M.B, M.β, M.C)
276277
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractColumnMajor,<:AbstractRowMajor,<:AbstractColumnMajor}) =
277278
_gemm!('N', 'T', M.α, M.A, transpose(M.B), M.β, M.C)
278279
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractColumnMajor,<:ConjLayout{<:AbstractRowMajor},<:AbstractColumnMajor,<:BlasComplex}) =
279-
_gemm!('N', 'C', M.α, M.A, M.B', M.β, M.C)
280+
_gemm!('N', 'C', M.α, M.A, adjoint(M.B), M.β, M.C)
280281

281282
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractRowMajor,<:AbstractColumnMajor,<:AbstractColumnMajor}) =
282283
_gemm!('T', 'N', M.α, transpose(M.A), M.B, M.β, M.C)
283284
@inline materialize!(M::BlasMatMulMatAdd{<:ConjLayout{<:AbstractRowMajor},<:AbstractColumnMajor,<:AbstractColumnMajor,<:BlasComplex}) =
284-
_gemm!('C', 'N', M.α, M.A', M.B, M.β, M.C)
285+
_gemm!('C', 'N', M.α, adjoint(M.A), M.B, M.β, M.C)
285286

286287
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractRowMajor,<:AbstractRowMajor,<:AbstractColumnMajor}) =
287288
_gemm!('T', 'T', M.α, transpose(M.A), transpose(M.B), M.β, M.C)
288289
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractRowMajor,<:ConjLayout{<:AbstractRowMajor},<:AbstractColumnMajor,<:BlasComplex}) =
289-
_gemm!('T', 'C', M.α, transpose(M.A), M.B', M.β, M.C)
290+
_gemm!('T', 'C', M.α, transpose(M.A), adjoint(M.B), M.β, M.C)
290291

291292
@inline materialize!(M::BlasMatMulMatAdd{<:ConjLayout{<:AbstractRowMajor},<:AbstractRowMajor,<:AbstractColumnMajor,<:BlasComplex}) =
292-
_gemm!('C', 'T', M.α, M.A', M.B', M.β, M.C)
293+
_gemm!('C', 'T', M.α, adjoint(M.A), transpose(M.B), M.β, M.C)
293294
@inline materialize!(M::BlasMatMulMatAdd{<:ConjLayout{<:AbstractRowMajor},<:ConjLayout{<:AbstractRowMajor},<:AbstractColumnMajor,<:BlasComplex}) =
294-
_gemm!('C', 'C', M.α, M.A', M.B', M.β, M.C)
295+
_gemm!('C', 'C', M.α, adjoint(M.A), adjoint(M.B), M.β, M.C)
295296

296297
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractColumnMajor,<:AbstractColumnMajor,<:AbstractRowMajor}) =
297298
_gemm!('T', 'T', M.α, M.B, M.A, M.β, transpose(M.C))
298299
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractColumnMajor,<:AbstractColumnMajor,<:ConjLayout{<:AbstractRowMajor},<:BlasComplex}) =
299-
_gemm!('C', 'C', M.α, M.B, M.A, M.β, M.C')
300+
_gemm!('C', 'C', M.α, M.B, M.A, M.β, adjoint(M.C))
300301

301302
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractColumnMajor,<:AbstractRowMajor,<:AbstractRowMajor}) =
302303
_gemm!('N', 'T', M.α, transpose(M.B), M.A, M.β, transpose(M.C))
303304
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractColumnMajor,<:AbstractRowMajor,<:ConjLayout{<:AbstractRowMajor},<:BlasComplex}) =
304305
_gemm!('N', 'T', M.α, transpose(M.B), M.A, M.β, M.C')
305306
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractColumnMajor,<:ConjLayout{<:AbstractRowMajor},<:ConjLayout{<:AbstractRowMajor},<:BlasComplex}) =
306-
_gemm!('N', 'C', M.α, M.B', M.A, M.β, M.C')
307+
_gemm!('N', 'C', M.α, adjoint(M.B), M.A, M.β, adjoint(M.C))
307308

308309
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractRowMajor,<:AbstractColumnMajor,<:AbstractRowMajor}) =
309310
_gemm!('T', 'N', M.α, M.B, transpose(M.A), M.β, transpose(M.C))
310311
@inline materialize!(M::BlasMatMulMatAdd{<:ConjLayout{<:AbstractRowMajor},<:AbstractColumnMajor,<:ConjLayout{<:AbstractRowMajor},<:BlasComplex}) =
311-
_gemm!('C', 'N', M.α, M.B, M.A', M.β, M.C')
312+
_gemm!('C', 'N', M.α, M.B, adjoint(M.A), M.β, adjoint(M.C))
312313

313314

314315
@inline materialize!(M::BlasMatMulMatAdd{<:AbstractRowMajor,<:AbstractRowMajor,<:AbstractRowMajor}) =
315316
_gemm!('N', 'N', M.α, transpose(M.B), transpose(M.A), M.β, transpose(M.C))
316317
@inline materialize!(M::BlasMatMulMatAdd{<:ConjLayout{<:AbstractRowMajor},<:ConjLayout{<:AbstractRowMajor},<:ConjLayout{<:AbstractRowMajor},<:BlasComplex}) =
317-
_gemm!('N', 'N', M.α, M.B', M.A', M.β, M.C')
318+
_gemm!('N', 'N', M.α, adjoint(M.B), adjoint(M.A), M.β, adjoint(M.C))
318319

319320

320321
###

test/test_muladd.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ end
168168
c .= MulAdd(3one(T),Ac,b,2one(T),c)
169169
@test all(c .=== BLAS.gemv!(trans, 3one(T), A, b, 2one(T), copy(b)))
170170
end
171+
172+
C = randn(6,6) + im*randn(6,6)
173+
V = view(C', 2:3, 3:4)
174+
c = randn(2) + im*randn(2)
175+
@test all(muladd!(1.0+0im,V,c,0.0+0im,similar(c,2)) .=== BLAS.gemv!('C', 1.0+0im, Matrix(V'), c, 0.0+0im, similar(c,2)))
176+
@test all(muladd!(1.0+0im,V',c,0.0+0im,similar(c,2)) .=== BLAS.gemv!('N', 1.0+0im, Matrix(V'), c, 0.0+0im, similar(c,2)))
171177
end
172178

173179
@testset "gemm adjtrans" begin

0 commit comments

Comments
 (0)