Skip to content

Commit a64c06e

Browse files
author
Will Kimmerer
authored
Slightly improve test coverage, non-standard map becomes apply, remove a few ambiguities, fix broadcast error (#63)
* initial * update asjulia and constructors * improve support for manual TypedOperators, hopefully improve test cov * test throws for ops * better coverage, concat corrections/coverage * tmtc
1 parent 854355d commit a64c06e

31 files changed

+507
-3066
lines changed

src/SuiteSparseGraphBLAS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ export clear!, extract, extract!, subassign!, assign!, hvcat! #array functions
113113

114114
#operations
115115
export mul, select, select!, eadd, eadd!, emul, emul!, map, map!, gbtranspose, gbtranspose!,
116-
gbrand, eunion, eunion!, mask, mask!
116+
gbrand, eunion, eunion!, mask, mask!, apply, apply!
117117
# Reexports from LinAlg
118118
export diag, diagm, mul!, kron, kron!, transpose, reduce, tril, triu
119119

src/asjulia.jl

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
function as(f::Function, ::Type{<:Union{Matrix, Vector}}, A::GBVecOrMat{T}; dropzeros=false, freeunpacked=false) where {T}
2-
if gbget(A, SPARSITY_STATUS) != GBDENSE
1+
function as(f::Function, type::Type{<:Union{Matrix, Vector}}, A::GBVecOrMat{T}; dropzeros=false, freeunpacked=false, nomodstructure = false) where {T}
2+
(type == Matrix && !(A isa GBMatrix)) && throw(ArgumentError("Cannot wrap $(typeof(A)) in a Matrix."))
3+
(type == Vector && !(A isa GBVector)) && throw(ArgumentError("Cannot wrap $(typeof(A)) in a Vector."))
4+
if gbget(A, SPARSITY_STATUS) != Int64(GBDENSE)
35
X = similar(A)
46
if X isa GBVector
57
X[:] = zero(T)
68
else
79
X[:,:] = zero(T)
810
end
9-
# I don't like this, it defeats the purpose of this method, which is to make no copies.
10-
# But somehow maintaining the input A in its original form is key to the to_vec implementation
11-
# for ChainRules/FiniteDiff. Temporarily it's fine, it's no worse than it originally was.
12-
# TODO: fix this issue with the ChainRules code.
13-
A = eadd(X, A)
11+
if nomodstructure
12+
A = eadd!(X, A, X)
13+
else
14+
A = eadd!(A, X, A)
15+
end
1416
end
17+
1518
array = _unpackdensematrix!(A)
1619
result = try
17-
f(array, A)
20+
f(array)
1821
finally
1922
if freeunpacked
2023
ccall(:jl_free, Cvoid, (Ptr{T},), pointer(array))
@@ -28,11 +31,11 @@ function as(f::Function, ::Type{<:Union{Matrix, Vector}}, A::GBVecOrMat{T}; drop
2831
return result
2932
end
3033

31-
function as(f::Function, ::SparseMatrixCSC, A::GBMatrix{T}; freeunpacked=false) where {T}
34+
function as(f::Function, ::Type{SparseMatrixCSC}, A::GBMatrix{T}; freeunpacked=false) where {T}
3235
colptr, rowidx, values = _unpackcscmatrix!(A)
3336
array = SparseMatrixCSC{T, LibGraphBLAS.GrB_Index}(size(A, 1), size(A, 2), colptr, rowidx, values)
3437
result = try
35-
f(array, A)
38+
f(array)
3639
finally
3740
if freeunpacked
3841
ccall(:jl_free, Cvoid, (Ptr{LibGraphBLAS.GrB_Index},), pointer(colptr))
@@ -45,11 +48,11 @@ function as(f::Function, ::SparseMatrixCSC, A::GBMatrix{T}; freeunpacked=false)
4548
return result
4649
end
4750

48-
function as(f::Function, ::SparseVector, A::GBVector{T}; freeunpacked=false) where {T}
51+
function as(f::Function, ::Type{SparseVector}, A::GBVector{T}; freeunpacked=false) where {T}
4952
colptr, rowidx, values = _unpackcscmatrix!(A)
5053
vector = SparseVector{T, LibGraphBLAS.GrB_Index}(size(A, 1), rowidx, values)
5154
result = try
52-
f(vector, A)
55+
f(vector)
5356
finally
5457
if freeunpacked
5558
ccall(:jl_free, Cvoid, (Ptr{LibGraphBLAS.GrB_Index},), pointer(colptr))
@@ -64,37 +67,41 @@ end
6467

6568

6669
function Base.Matrix(A::GBMatrix)
67-
return as(Matrix, A) do arr, _
70+
# we use nomodstructure here to avoid the pitfall of densifying A.
71+
return as(Matrix, A; nomodstructure=true) do arr
6872
return copy(arr)
6973
end
7074
end
7175

7276
function Matrix!(A::GBMatrix)
73-
return as(Matrix, A; freeunpacked=true) do arr, _
77+
# we use nomodstructure here to avoid the pitfall of densifying A.
78+
return as(Matrix, A; freeunpacked=true) do arr
7479
return copy(arr)
7580
end
7681
end
7782

7883
function Base.Vector(v::GBVector)
79-
return as(Vector, v) do vec, _
84+
# we use nomodstructure here to avoid the pitfall of densifying A.
85+
return as(Vector, v; nomodstructure=true) do vec
8086
return copy(vec)
8187
end
8288
end
8389

8490
function Vector!(v::GBVector)
85-
return as(Vector, v; freeunpacked=true) do vec, _
91+
# we use nomodstructure here to avoid the pitfall of densifying A.
92+
return as(Vector, v; freeunpacked=true) do vec
8693
return copy(vec)
8794
end
8895
end
8996

9097
function SparseArrays.SparseMatrixCSC(A::GBMatrix)
91-
return as(SparseMatrixCSC, A) do arr, _
98+
return as(SparseMatrixCSC, A) do arr
9299
return copy(arr)
93100
end
94101
end
95102

96103
function SparseArrays.SparseVector(v::GBVector)
97-
return as(SparseVector, v) do arr, _
104+
return as(SparseVector, v) do arr
98105
return copy(arr)
99106
end
100107
end

src/chainrules/constructorrules.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Dense vector construction
22
function frule(
3-
(_, Δv),
3+
(_, Δv)::Tuple,
44
::Type{<:GBVector},
55
v::Vector{T}
66
) where {T}
@@ -16,7 +16,7 @@ end
1616

1717
# Dense matrix construction
1818
function frule(
19-
(_, ΔA),
19+
(_, ΔA)::Tuple,
2020
::Type{<:GBMatrix},
2121
A::Matrix{T}
2222
) where {T}
@@ -32,7 +32,7 @@ end
3232

3333
# Dense matrix from vector (n x 1 matrix)
3434
function frule(
35-
(_, ΔA),
35+
(_, ΔA)::Tuple,
3636
::Type{<:GBMatrix},
3737
A::Vector{T}
3838
) where {T}
@@ -50,7 +50,7 @@ end
5050

5151
# Sparse Vector
5252
function frule(
53-
(_, _, Δv),
53+
(_, _, Δv)::Tuple,
5454
::Type{<:GBVector},
5555
I::AbstractVector{U},
5656
v::Vector{T}
@@ -67,7 +67,7 @@ end
6767

6868
# Sparse Matrix
6969
function frule(
70-
(_,_,_,Δv),
70+
(_,_,_,Δv)::Tuple,
7171
::Type{<:GBMatrix},
7272
I::AbstractVector{U},
7373
J::AbstractVector{U},
@@ -89,7 +89,7 @@ function rrule(
8989
end
9090

9191
function frule(
92-
(_,_,Δv),
92+
(_,_,Δv)::Tuple,
9393
::Type{<:GBMatrix},
9494
I::AbstractVector{U},
9595
v::Vector{T}
@@ -109,7 +109,7 @@ function rrule(
109109
end
110110

111111
function frule(
112-
(_,ΔS),
112+
(_,ΔS)::Tuple,
113113
::Type{GBMatrix},
114114
S::SparseMatrixCSC{T}
115115
) where {T}
@@ -128,7 +128,7 @@ function rrule(
128128
end
129129

130130
function frule(
131-
(_,ΔS),
131+
(_,ΔS)::Tuple,
132132
::Type{GBMatrix},
133133
S::SparseVector{T}
134134
) where {T}

src/chainrules/ewiserules.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#emul TIMES
22
function frule(
3-
(_, ΔA, ΔB, _),
3+
(_, ΔA, ΔB, _)::Tuple,
44
::typeof(emul),
55
A::GBArray,
66
B::GBArray,
@@ -10,7 +10,7 @@ function frule(
1010
∂Ω = emul(unthunk(ΔA), B, *) + emul(unthunk(ΔB), A, *)
1111
return Ω, ∂Ω
1212
end
13-
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray)
13+
function frule((_, ΔA, ΔB)::Tuple, ::typeof(emul), A::GBArray, B::GBArray)
1414
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, *)
1515
end
1616

@@ -37,7 +37,7 @@ end
3737
######
3838

3939
function frule(
40-
(_, ΔA, ΔB, _),
40+
(_, ΔA, ΔB, _)::Tuple,
4141
::typeof(eadd),
4242
A::GBArray,
4343
B::GBArray,
@@ -47,7 +47,7 @@ function frule(
4747
∂Ω = eadd(unthunk(ΔA), unthunk(ΔB), +)
4848
return Ω, ∂Ω
4949
end
50-
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray)
50+
function frule((_, ΔA, ΔB)::Tuple, ::typeof(eadd), A::GBArray, B::GBArray)
5151
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, +)
5252
end
5353

src/chainrules/maprules.jl

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,84 @@
1-
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather
1+
# Per Lyndon. Needs adaptation, and/or needs redefinition of apply to use functions rather
22
# than AbstractOp.
3-
#function rrule(map, f, xs)
4-
# # Rather than 3 maps really want 1 multimap
5-
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x)
6-
# ys = map(first, ys_and_pullbacks)
7-
# pullbacks = map(last, ys_and_pullbacks)
8-
# function map_pullback(dys)
3+
#function rrule(apply, f, xs)
4+
# # Rather than 3 applys really want 1 multiapply
5+
# ys_and_pullbacks = apply(x->rrule(f, x), xs) #Take this to ys = apply(f, x)
6+
# ys = apply(first, ys_and_pullbacks)
7+
# pullbacks = apply(last, ys_and_pullbacks)
8+
# function apply_pullback(dys)
99
# _call(f, x) = f(x)
10-
# dfs_and_dxs = map(_call, pullbacks, dys)
10+
# dfs_and_dxs = apply(_call, pullbacks, dys)
1111
# # but in your case you know it will be NoTangent() so can skip
1212
# df = sum(first, dfs_and_dxs)
13-
# dxs = map(last, dfs_and_dxs)
13+
# dxs = apply(last, dfs_and_dxs)
1414
# return NoTangent(), df, dxs
1515
# end
16-
# return ys, map_pullback
16+
# return ys, apply_pullback
1717
#end
18-
macro scalarmaprule(func, derivative)
18+
macro scalarapplyrule(func, derivative)
1919
return ChainRulesCore.@strip_linenos quote
2020
function ChainRulesCore.frule(
21-
(_, _, $(esc(:ΔA))),
22-
::typeof(Base.map),
21+
(_, _, $(esc(:ΔA)))::Tuple,
22+
::typeof(apply),
2323
::typeof($(func)),
2424
$(esc(:A))::GBArray
2525
)
26-
$(esc()) = map($(esc(func)), $(esc(:A)))
26+
$(esc()) = apply($(esc(func)), $(esc(:A)))
2727
return $(esc()), $(esc(derivative)) .* unthunk($(esc(:ΔA)))
2828
end
2929
function ChainRulesCore.rrule(
30-
::typeof(Base.map),
30+
::typeof(apply),
3131
::typeof($(func)),
3232
$(esc(:A))::GBArray
3333
)
34-
$(esc()) = map($(esc(func)), $(esc(:A)))
35-
function mapback($(esc(:ΔA)))
34+
$(esc()) = apply($(esc(func)), $(esc(:A)))
35+
function applyback($(esc(:ΔA)))
3636
NoTangent(), NoTangent(), $(esc(derivative)) .* $(esc(:ΔA))
3737
end
38-
return $(esc()), mapback
38+
return $(esc()), applyback
3939
end
4040
end
4141
end
4242

4343
function ChainRulesCore.frule(
44-
(_,_,ΔA),
45-
::typeof(map),
44+
(_,_,ΔA)::Tuple,
45+
::typeof(apply),
4646
::typeof(sqrt),
4747
A::Array
4848
)
49-
Ω = map(sqrt, A)
49+
Ω = apply(sqrt, A)
5050
return Ω, inv.(2 .* Ω)
5151
end
5252

5353
#Trig
54-
@scalarmaprule sin cos.(A)
55-
@scalarmaprule cos -sin.(A)
56-
@scalarmaprule tan @. 1 +^ 2)
54+
@scalarapplyrule sin cos.(A)
55+
@scalarapplyrule cos -sin.(A)
56+
@scalarapplyrule tan @. 1 +^ 2)
5757

5858
#Hyperbolic Trig
59-
@scalarmaprule sinh cosh.(A)
60-
@scalarmaprule cosh sinh.(A)
61-
@scalarmaprule tanh @. 1 -^ 2)
59+
@scalarapplyrule sinh cosh.(A)
60+
@scalarapplyrule cosh sinh.(A)
61+
@scalarapplyrule tanh @. 1 -^ 2)
6262

63-
@scalarmaprule inv -.^ 2)
64-
@scalarmaprule exp Ω
63+
@scalarapplyrule inv -.^ 2)
64+
@scalarapplyrule exp Ω
6565

66-
@scalarmaprule abs sign.(A)
66+
@scalarapplyrule abs sign.(A)
6767
#Anything that uses MINV fails the isapprox tests :().
6868
# Since in the immortal words of Miha - "FiniteDiff is smarter than you", these shouldn't be enabled.
69-
#@scalarmaprule UnaryOps.ASIN @. inv(sqrt.(1 - A ^ 2))
70-
#@scalarmaprule UnaryOps.ACOS @. inv(sqrt.(1 - A ^ 2))
71-
#@scalarmaprule UnaryOps.ATAN @. inv(1 + A ^ 2)
72-
#@scalarmaprule UnaryOps.SQRT inv.(2 .* Ω)
69+
#@scalarapplyrule UnaryOps.ASIN @. inv(sqrt.(1 - A ^ 2))
70+
#@scalarapplyrule UnaryOps.ACOS @. inv(sqrt.(1 - A ^ 2))
71+
#@scalarapplyrule UnaryOps.ATAN @. inv(1 + A ^ 2)
72+
#@scalarapplyrule UnaryOps.SQRT inv.(2 .* Ω)
7373

7474
function frule(
75-
(_, _, ΔA),
76-
::typeof(map),
75+
(_, _, ΔA)::Tuple,
76+
::typeof(apply),
7777
::typeof(identity),
7878
A::GBArray
7979
)
8080
return (A, ΔA)
8181
end
82-
function rrule(::typeof(map), ::typeof(identity), A::GBArray)
82+
function rrule(::typeof(apply), ::typeof(identity), A::GBArray)
8383
return A, (ΔΩ) -> (NoTangent(), NoTangent(), ΔΩ)
8484
end

src/chainrules/mulrules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#PLUS REDUCERS:
22
###############
33
function frule(
4-
(_, ΔA, ΔB),
4+
(_, ΔA, ΔB)::Tuple,
55
::typeof(mul),
66
A::GBArray,
77
B::GBArray
88
)
99
frule((nothing, ΔA, ΔB, nothing), mul, A, B, (+, *))
1010
end
1111
function frule(
12-
(_, ΔA, ΔB, _),
12+
(_, ΔA, ΔB, _)::Tuple,
1313
::typeof(mul),
1414
A::GBArray,
1515
B::GBArray,
@@ -64,7 +64,7 @@ end
6464

6565
# PLUS_PLUS:
6666
function frule(
67-
(_, ΔA, ΔB, _),
67+
(_, ΔA, ΔB, _)::Tuple,
6868
::typeof(mul),
6969
A::GBArray,
7070
B::GBArray,
@@ -91,7 +91,7 @@ end
9191

9292
# PLUS_MINUS:
9393
function frule(
94-
(_, ΔA, ΔB, _),
94+
(_, ΔA, ΔB, _)::Tuple,
9595
::typeof(mul),
9696
A::GBArray,
9797
B::GBArray,
@@ -118,7 +118,7 @@ end
118118

119119
# PLUS_FIRST:
120120
function frule(
121-
(_, ΔA, ΔB, _),
121+
(_, ΔA, ΔB, _)::Tuple,
122122
::typeof(mul),
123123
A::GBArray,
124124
B::GBArray,
@@ -145,7 +145,7 @@ end
145145

146146
# PLUS_SECOND:
147147
function frule(
148-
(_, ΔA, ΔB, _),
148+
(_, ΔA, ΔB, _)::Tuple,
149149
::typeof(mul),
150150
A::GBArray,
151151
B::GBArray,

0 commit comments

Comments
 (0)