Skip to content

Commit ad5d673

Browse files
committed
add idxunaryops back finally
1 parent 075d731 commit ad5d673

File tree

9 files changed

+459
-229
lines changed

9 files changed

+459
-229
lines changed

src/SuiteSparseGraphBLAS.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ include("operators/libgbops.jl")
3838

3939
include("gbtypes.jl")
4040
include("types.jl")
41+
include("scalar.jl")
4142
include("mem.jl")
4243

4344

@@ -55,11 +56,11 @@ using .UnaryOps
5556
using .BinaryOps
5657
using .Monoids
5758
using .Semirings
59+
using .IndexUnaryOps
5860

5961
include("indexutils.jl")
6062
#
6163
include("operations/extract.jl")
62-
include("scalar.jl")
6364
include("gbvector.jl")
6465
include("gbmatrix.jl")
6566
include("abstractgbarray.jl")
@@ -159,8 +160,6 @@ function __init__()
159160
for type valid_vec
160161
Base.unsafe_convert(LibGraphBLAS.GrB_Type, gbtype(type))
161162
end
162-
# Eagerly load selectops constants.
163-
_loadselectops()
164163
ALL.p = load_global("GrB_ALL", LibGraphBLAS.GrB_Index)
165164
# Set printing done by SuiteSparse:GraphBLAS to base-1 rather than base-0.
166165
gbset(BASE1, 1)

src/chainrules/selectrules.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
function frule(
33
(_, _, ΔA)::Tuple,
44
::typeof(select),
5-
op::Union{Function, SelectUnion},
5+
op::Function,
66
A::AbstractGBArray
77
)
88
Ω = select(op, A)
@@ -14,7 +14,7 @@ end
1414
function frule(
1515
(_, _, ΔA, _)::Tuple,
1616
::typeof(select),
17-
op::Union{Function, SelectUnion},
17+
op::Function,
1818
A::AbstractGBArray,
1919
thunk::Union{GBScalar, Nothing, valid_union}
2020
)
@@ -25,7 +25,7 @@ end
2525

2626
function rrule(
2727
::typeof(select),
28-
op::Union{Function, SelectUnion},
28+
op::Function,
2929
A::AbstractGBArray
3030
)
3131
out = select(op, A)
@@ -38,7 +38,7 @@ end
3838

3939
function rrule(
4040
::typeof(select),
41-
op::Union{Function, SelectUnion},
41+
op::Function,
4242
A::AbstractGBArray,
4343
thunk::Union{GBScalar, Nothing, valid_union}
4444
)

src/operations/select.jl

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,56 @@
11
# TODO: update to modern op system.
22

3+
function defaultselectthunk(op, T)
4+
if op (rowindex, colindex, diagindex)
5+
return one(Int64)
6+
elseif op (tril, triu, diag, offdiag)
7+
return zero(Int64)
8+
elseif op === ==
9+
return zero(T)
10+
else
11+
throw(ArgumentError("You must pass `thunk` to select for this function."))
12+
end
13+
end
14+
315
"In place version of `select`."
416
function select!(
517
op,
618
C::GBVecOrMat,
7-
A::GBArrayOrTranspose,
8-
thunk = nothing;
19+
A::GBArrayOrTranspose{T},
20+
thunk::TH = defaultselectthunk(op, T);
921
mask = nothing,
1022
accum = nothing,
1123
desc = nothing
12-
)
24+
) where {T, TH}
25+
op (rowindex, colindex, diagindex, tril, triu, diag, offdiag) &&
26+
(thunk = convert(Int64, thunk))
1327
_canbeoutput(C) || throw(ShallowException())
14-
op = SelectOp(op)
28+
op = indexunaryop(op, T, TH)
1529
desc = _handledescriptor(desc; out=C, in1=A)
1630
mask = _handlemask!(desc, mask)
17-
thunk === nothing && (thunk = C_NULL)
1831
accum = _handleaccum(accum, storedeltype(C))
19-
if thunk isa Number
20-
thunk = GBScalar(thunk)
21-
end
22-
@wraperror LibGraphBLAS.GxB_Matrix_select(C, mask, accum, op, parent(A), thunk, desc)
32+
@wraperror LibGraphBLAS.GrB_Matrix_select_Scalar(C, mask, accum, op, parent(A), GBScalar(thunk), desc)
2333
return C
2434
end
2535

26-
function select!(op, A::GBArrayOrTranspose, thunk = nothing; mask = nothing, accum = nothing, desc = nothing)
36+
function select!(
37+
op, A::GBArrayOrTranspose{T}, thunk = defaultselectthunk(op, T);
38+
mask = nothing, accum = nothing, desc = nothing
39+
) where T
2740
return select!(op, A, A, thunk; mask, accum, desc)
2841
end
2942

3043
"""
31-
select(op::Union{Function, SelectUnion}, A::GBArrayOrTranspose; kwargs...)::GBArrayOrTranspose
32-
select(op::Union{Function, SelectUnion}, A::GBArrayOrTranspose, thunk; kwargs...)::GBArrayOrTranspose
44+
select(op::Function, A::GBArrayOrTranspose; kwargs...)::GBArrayOrTranspose
45+
select(op::Function, A::GBArrayOrTranspose, thunk; kwargs...)::GBArrayOrTranspose
3346
3447
Return a `GBArray` whose elements satisfy the predicate defined by `op`.
3548
Some SelectOps or functions may require an additional argument `thunk`, for use in
3649
comparison operations such as `C[i,j] = A[i,j] >= thunk ? A[i,j] : nothing`, which is
3750
performed by `select(>, A, thunk)`.
3851
3952
# Arguments
40-
- `op::Union{Function, SelectUnion}`: A select operator from the SelectOps submodule.
53+
- `op::Function`: A select operator from the SelectOps submodule.
4154
- `A::GBArrayOrTranspose`
4255
- `thunk::Union{GBScalar, nothing, valid_union}`: Optional value used to evaluate `op`.
4356
@@ -53,19 +66,19 @@ Some SelectOps or functions may require an additional argument `thunk`, for use
5366
"""
5467
function select(
5568
op,
56-
A::GBArrayOrTranspose,
57-
thunk = nothing;
69+
A::GBArrayOrTranspose{T},
70+
thunk::TH = defaultselectthunk(op, T);
5871
mask = nothing,
5972
accum = nothing,
6073
desc = nothing
61-
)
62-
op = SelectOp(op)
63-
C = similar(A)
74+
) where {T, TH}
75+
op = indexunaryop(op, T, TH)
76+
C = similar(A) # we keep the same type!! not the ztype of op.
6477
select!(op, C, A, thunk; accum, mask, desc)
6578
return C
6679
end
6780

6881
LinearAlgebra.tril(A::GBArrayOrTranspose, k::Integer = 0) = select(tril, A, k)
6982
LinearAlgebra.triu(A::GBArrayOrTranspose, k::Integer = 0) = select(triu, A, k)
70-
SparseArrays.dropzeros(A::GBArrayOrTranspose) = select(nonzeros, A)
71-
SparseArrays.dropzeros!(A::GBArrayOrTranspose) = select!(nonzeros, A)
83+
SparseArrays.dropzeros(A::GBArrayOrTranspose{T}) where T = select(!=, A, zero(T))
84+
SparseArrays.dropzeros!(A::GBArrayOrTranspose{T}) where T = select!(!=, A, zero(T))

src/operators/binaryops.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ end
2525
SuiteSparseGraphBLAS.binaryop(f::F, ::Type{X}, ::Type{X}) where {F, X} = fallback_binaryop(f, X, X)
2626

2727
function SuiteSparseGraphBLAS.binaryop(
28-
f::F, ::Type{X}, ::Type{Y}
29-
) where {F, X, Y}
28+
f, ::Type{X}, ::Type{Y}
29+
) where {X, Y}
3030
P = promote_type(X, Y)
3131
if isconcretetype(P) && (X <: valid_union && Y <: valid_union)
3232
return binaryop(f, P, P)
@@ -41,8 +41,8 @@ SuiteSparseGraphBLAS.binaryop(f, ::GBArrayOrTranspose{T}, ::Type{U}) where {T, U
4141
SuiteSparseGraphBLAS.binaryop(f, ::Type{T}, ::GBArrayOrTranspose{U}) where {T, U} = binaryop(f, T, U)
4242

4343
SuiteSparseGraphBLAS.binaryop(f, type) = binaryop(f, type, type)
44-
SuiteSparseGraphBLAS.binaryop(op::TypedBinaryOperator, x...) = op
45-
44+
SuiteSparseGraphBLAS.binaryop(op::TypedBinaryOperator, ::Type{X}, ::Type{Y}) where {X, Y} = op
45+
SuiteSparseGraphBLAS.binaryop(op::TypedBinaryOperator, ::Type{X}) where X = op
4646
SuiteSparseGraphBLAS.juliaop(op::TypedBinaryOperator) = op.fn
4747

4848
# TODO, clean up this function, it allocates typedop and is otherwise perhaps a little slow.

0 commit comments

Comments
 (0)