Skip to content

Commit 5b928aa

Browse files
committed
sufficient coverage of cholmod, wrappers still very broken. improve broadcasting slightly
1 parent 8274567 commit 5b928aa

File tree

11 files changed

+378
-327
lines changed

11 files changed

+378
-327
lines changed

src/SuiteSparseGraphBLAS.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ include("chainrules/constructorrules.jl")
101101
include("serialization.jl")
102102

103103
#EXPERIMENTAL
104+
include("linalg.jl")
104105
include("misc.jl")
105106
include("mmread.jl")
106107
include("iterator.jl")
@@ -133,7 +134,7 @@ export extract, extract!, subassign!, assign!, hvcat! #array functions
133134

134135
#operations
135136
export select, select!, eadd, eadd!, emul, emul!, gbtranspose, gbtranspose!,
136-
gbrand, eunion, eunion!, mask, mask!, apply, apply!, setfill, setfill!
137+
gbrand, eunion, eunion!, mask, mask!, apply, apply!, setfill, setfill!, gbrandn
137138
# Reexports from LinAlg
138139
export diag, diagm, mul!, kron, kron!, transpose, reduce, tril, triu
139140

src/abstractgbarray.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,12 @@ function build!(A::AbstractGBMatrix{T}, I::AbstractVector, J::AbstractVector, x:
247247
@wraperror LibGraphBLAS.GxB_Matrix_build_Scalar(
248248
A,
249249
Vector{LibGraphBLAS.GrB_Index}(decrement!(I)),
250-
Vector{LibGraphBLAS.GrB_Index}(decrement!(J)),
250+
Vector{LibGraphBLAS.GrB_Index}(I !== J ? decrement!(J) : J),
251251
x,
252252
length(I)
253253
)
254254
increment!(I)
255-
increment!(J)
255+
I !== J && (increment!(J))
256256
return A
257257
end
258258

@@ -296,7 +296,7 @@ for T ∈ valid_vec
296296
length(X) == length(I) == length(J) ||
297297
DimensionMismatch("I, J and X must have the same length")
298298
decrement!(I)
299-
decrement!(J)
299+
I !== J && (decrement!(J))
300300
@wraperror LibGraphBLAS.$func(
301301
A,
302302
I,
@@ -306,7 +306,7 @@ for T ∈ valid_vec
306306
combine
307307
)
308308
increment!(I)
309-
increment!(J)
309+
I !== J && (increment!(J))
310310
return A
311311
end
312312
end
@@ -535,7 +535,7 @@ function subassign!(
535535
desc = _handledescriptor(desc; out=C, in1=A)
536536
mask = _handlemask!(desc, mask)
537537
I = decrement!(I)
538-
J = decrement!(J)
538+
I !== J && (J = decrement!(J))
539539
rereshape = false
540540
sz1 = size(A, 1)
541541
if !(eltype(A) <: valid_union) || !(eltype(C) <: valid_union)
@@ -553,7 +553,7 @@ function subassign!(
553553
parent(A), true, sz1, 1, C_NULL)
554554
end
555555
increment!(I)
556-
increment!(J)
556+
I !== J && (increment!(J))
557557
return A
558558
end
559559

@@ -565,12 +565,12 @@ function subassign!(C::AbstractGBArray{T}, x, I, J;
565565
I, ni = idx(I)
566566
J, nj = idx(J)
567567
I = decrement!(I)
568-
J = decrement!(J)
568+
I !== J && (J = decrement!(J))
569569
desc = _handledescriptor(desc; out=C)
570570
mask = _handlemask!(desc, mask)
571571
_subassign(C, x, I, ni, J, nj, mask, _handleaccum(accum, storedeltype(C)), desc)
572572
increment!(I)
573-
increment!(J)
573+
I !== J && (decrement!(J))
574574
return x
575575
end
576576

@@ -638,13 +638,13 @@ function assign!(
638638
desc = _handledescriptor(desc; in1=A, out=C)
639639
mask = _handlemask!(desc, mask)
640640
I = decrement!(I)
641-
J = decrement!(J)
641+
I !== J && (J = decrement!(J))
642642
if !(eltype(A) <: valid_union) || !(eltype(C) <: valid_union)
643643
A = LinearAlgebra.copy_oftype(A, eltype(C))
644644
end
645645
@wraperror LibGraphBLAS.GrB_Matrix_assign(C, mask, _handleaccum(accum, storedeltype(C)), parent(A), I, ni, J, nj, desc)
646646
increment!(I)
647-
increment!(J)
647+
I !== J && (decrement!(J))
648648
return A
649649
end
650650

@@ -656,12 +656,12 @@ function assign!(C::AbstractGBArray{T}, x, I, J;
656656
I, ni = idx(I)
657657
J, nj = idx(J)
658658
I = decrement!(I)
659-
J = decrement!(J)
659+
I !== J && (J = decrement!(J))
660660
desc = _handledescriptor(desc; out=C)
661661
mask = _handlemask!(desc, mask)
662662
_assign(C, x, I, ni, J, nj, mask, _handleaccum(accum, storedeltype(C)), desc)
663663
increment!(I)
664-
increment!(J)
664+
I !== J && (decrement!(J))
665665
return x
666666
end
667667

@@ -719,7 +719,7 @@ function GBDiagonal!(C::AbstractGBMatrix, v::AbstractGBVector, k::Integer=0; des
719719
return C
720720
end
721721
function GBDiagonal!(C::AbstractGBMatrix{T}, v::AbstractVector, k::Integer=0; desc = nothing) where T
722-
v2 = GBShallowVector(convert(DenseVector{T}, v))
722+
v2 = GBShallowVector(convert(Vector{T}, v))
723723
GBDiagonal!(C, v2, k; desc)
724724
end
725725
function GBDiagonal!(C::AbstractGBMatrix, D::Diagonal; desc = nothing)
@@ -736,7 +736,6 @@ function GBDiagonal(v::AbstractGBVector, k::Integer=0; desc = nothing)
736736
GBDiagonal!(C, v, k; desc)
737737
end
738738

739-
740739
# Type dependent functions build, setindex, getindex, and findnz:
741740
for T valid_vec
742741
if T gxb_vec

src/iterator.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
# TODO: Make this file less loathsome.
22

3+
# These iterators are only somewhat useful.
4+
# The biggest issue is that we can't, at the moment, use one to change values in another.
5+
# So in place map! is doable like this, as is just simply getting values.
6+
# But if you wanted to say `C[I] = <transform>(A[I])` even with identical structures
7+
# there's no method. Non-identical patterns are off the table entirely.
8+
39
abstract type IndexIteratorType end
410
struct IndicesIterator <: IndexIteratorType end # return the indices as integers
511
struct NeighborIterator <: IndexIteratorType end # Used only for Vector iterators
612
struct NoIndexIterator <: IndexIteratorType end # Used when we don't want the indices, just the value
13+
struct IteratorIterator <: IndexIteratorType end # return the iterator object itself. Spooky!
714
# just returns the free index, only useful for single column/row iterators.
815

916
abstract type AbstractGBIterator end
@@ -143,13 +150,20 @@ end
143150
return unsafe_load(Ptr{T}(I.p.Ax[]), I.p.iso[] ? 1 : I.p.p[] + 1)
144151
end
145152

153+
@inline function setval(I::GxBIterator{<:Any, T}, x::T) where T
154+
I.p.iso[] && throw(ArgumentError("Cannot set value of iso valued matrix using iterator."))
155+
unsafe_store!(Ptr{T}(I.p.Ax[]), x, I.p.p[] + 1)
156+
end
157+
146158
# TODO: Inelegant
147159
@inline get_element(I::GxBIterator{<:Any, <:Any, true, IndicesIterator()}) = ((getrow(I), getcol(I)), getval(I))
148160
@inline get_element(I::GxBIterator{<:Any, <:Any, false, IndicesIterator()}) = (getrow(I), getcol(I))
149161
@inline get_element(I::GxBIterator{<:Any, <:Any, true, NeighborIterator()}) = (increment(_rc_geti(I)), getval(I))
150162
@inline get_element(I::GxBIterator{<:Any, <:Any, false, NeighborIterator()}) = increment(_rc_geti(I))
151163
@inline get_element(I::GxBIterator{<:Any, <:Any, true, NoIndexIterator()}) = getval(I)
152164
@inline get_element(::GxBIterator{<:Any, <:Any, false, NoIndexIterator()}) = throw(ArgumentError("Must iterate over either indices or values."))
165+
@inline get_element(I::GxBIterator{<:Any, <:Any, <:Any, IteratorIterator()}) = I
166+
153167

154168
struct VectorIterator{B, O, T, IterateValues, IterationType, I<:GxBIterator}
155169
iterator::I
@@ -228,10 +242,14 @@ end
228242
end
229243

230244
get_element(I::VectorIterator) = get_element(I.iterator)
245+
get_element(I::VectorIterator{<:Any, <:Any, <:Any, <:Any, IteratorIterator()}) = I
231246

232247
const RowIterator{B, T, IterateValues, IterationType} = VectorIterator{B, RowMajor(), T, IterateValues, IterationType}
233248
const ColIterator{B, T, IterateValues, IterationType} = VectorIterator{B, ColMajor(), T, IterateValues, IterationType}
234249

250+
defaultiteration(::Integer) = NeighborIterator()
251+
defaultiteration(::AbstractVector) = IndicesIterator()
252+
235253
RowIterator(A::AbstractGBArray, v, iteratevalues::Bool, indexiteration::IndexIteratorType = NeighborIterator()) =
236254
storageorder(A) === RowMajor() ? VectorIterator{iteratevalues, indexiteration}(A, v) :
237255
throw(ArgumentError("A is not in RowMajor() order. Row iteration is only supported on RowMajor AbstractGBArrays. Try setstorageorder[!]"))
@@ -255,4 +273,9 @@ iteraterows(A::AbstractGBArray, v, indexiteration::IndexIteratorType = NeighborI
255273
function Base.iterate(I::VectorIterator)
256274
return _seek(I.iterator, I.v isa Int64 ? I.v : I.v.start) == LibGraphBLAS.GrB_NO_VALUE ? knext(I) : (get_element(I), nothing)
257275
end
258-
Base.iterate(I::VectorIterator, ::Nothing) = inext(I)
276+
Base.iterate(I::VectorIterator, ::Nothing) = inext(I)
277+
278+
Base.getindex(A::AbstractArray, I::GxBIterator) = getindex(A, getrow(I), getcol(I))
279+
Base.getindex(A::AbstractArray, v::VectorIterator) = getindex(A, v.iterator)
280+
Base.setindex!(A::AbstractArray, x, I::GxBIterator) = setindex!(A, x, getrow(I), getcol(I))
281+
Base.setindex!(A::AbstractArray, x, v::VectorIterator) = setindex!(A, x, v.iterator)

src/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# not quite as fast as it *could* be, we have to construct the sparse representation of
2+
# a diagonal. Which means ~triple the size of a basic Diagonal matrix.
3+
# SSGrB doesn't have an internal representation for a Diagonal.
4+
LinearAlgebra.:\(D::Diagonal, B::AbstractGBMatrix) = *(D, B, (any, \))

src/operations/broadcasts.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ modifying(::typeof(*)) = mul!
6363
modifying(::typeof(eadd)) = eadd!
6464
modifying(::typeof(emul)) = emul!
6565

66+
# TODO: Fix this horrifically ugly function.
67+
# TODO: Remove the requirement that vector -> matrix
68+
# broadcast be done on certain side.
69+
# We can add an `iscommutative` to solve this.
70+
# as well as some function to get the reversed version of the operator.
6671
@inline function Base.copy(bc::Broadcast.Broadcasted{GBMatrixStyle})
6772
f = bc.f
6873
l = length(bc.args)
@@ -92,6 +97,28 @@ modifying(::typeof(emul)) = emul!
9297
if right isa StridedArray
9398
right = pack(right; fill = left isa GBArrayOrTranspose ? getfill(left) : nothing)
9499
end
100+
# TODO: We want to expose the broadcasting of Vectors into Matrices.
101+
# The only problem is we need a notion of commutativity.
102+
# This works fine for builtins, we can define commutativity pretty easily
103+
# for many operations. But for non-builtins we'd need an API of sorts.
104+
# To get around this for now we will require that Vectors be on the left
105+
# and transposed vectors be on the right.
106+
if left isa AbstractGBVector && right isa GBMatrixOrTranspose
107+
return *(Diagonal(left), right, (any, f))
108+
end
109+
if left isa GBMatrixOrTranspose && right isa Transpose{<:Any, <:AbstractGBVector}
110+
return *(left, Diagonal(right), (any, f))
111+
end
112+
if left isa GBMatrixOrTranspose && right isa AbstractGBVector
113+
throw(ArgumentError(
114+
"Broadcasting a GBVector into a GBMatrix is only currently " *
115+
"supported with the GBVector on the left."))
116+
end
117+
if right isa GBMatrixOrTranspose && left isa Transpose{<:Any, <:AbstractGBVector}
118+
throw(ArgumentError(
119+
"Broadcasting a Transpose{<:Any, <:AbstractGBVector} into a GBMatrix" *
120+
" is only currently supported with the GBVector on the right."))
121+
end
95122
if left isa GBArrayOrTranspose && right isa GBArrayOrTranspose
96123
add = defaultadd(f)
97124
return add(left, right, f)

src/operations/extract.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ function extract!(
6868
desc = _handledescriptor(desc; out=C, in1 = A)
6969
mask = _handlemask!(desc, mask)
7070
I = decrement!(I)
71-
J = decrement!(J)
71+
I !== J && (J = decrement!(J))
7272
@wraperror LibGraphBLAS.GrB_Matrix_extract(C, mask, _handleaccum(accum, storedeltype(C)), parent(A), I, ni, J, nj, desc)
73-
I isa AbstractVector && increment!(I)
74-
J isa AbstractVector && increment!(J)
73+
increment!(I)
74+
I !== J && increment!(J)
7575
return C
7676
end
7777
"""
@@ -140,7 +140,7 @@ function extract!(
140140
desc = _handledescriptor(desc; out=w)
141141
mask = _handlemask!(desc, mask)
142142
@wraperror LibGraphBLAS.GrB_Matrix_extract(w, mask, _handleaccum(accum, storedeltype(w)), u, I, ni, UInt64[0], 1, desc)
143-
I isa AbstractVector && increment!(I)
143+
increment!(I)
144144
return w
145145
end
146146

src/operations/mul.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,13 @@ function Base.:*(rig::TypedSemiring)
191191
end
192192

193193
# Diagonal
194-
function Base.:*(D::Diagonal, A::G) where
194+
function Base.:*(D::Diagonal, A::G, op = (+, *); mask = nothing, accum = nothing, desc = nothing) where
195195
{G <: Union{Transpose{T, <:SuiteSparseGraphBLAS.AbstractGBArray{T1, F, O}} where {T, T1, F, O},
196196
SuiteSparseGraphBLAS.AbstractGBArray{T, F, O, 2} where {T, F, O}}}
197-
return *(G(D), A)
197+
return *(G(D), A, op; mask, accum, desc)
198198
end
199-
function Base.:*(A::G, D::Diagonal) where
199+
function Base.:*(A::G, D::Diagonal, op = (+, *); mask = nothing, accum = nothing, desc = nothing) where
200200
{G <: Union{Transpose{T, <:SuiteSparseGraphBLAS.AbstractGBArray{T1, F, O}} where {T, T1, F, O},
201201
SuiteSparseGraphBLAS.AbstractGBArray{T, F, O, 2} where {T, F, O}}}
202-
return *(G(D), A)
202+
return *(A, G(D), op; mask, accum, desc)
203203
end

0 commit comments

Comments
 (0)