Skip to content

Commit a0c8344

Browse files
author
Wimmerer
committed
with function/ctxvarx, infer output type better
1 parent 96b7296 commit a0c8344

File tree

13 files changed

+204
-164
lines changed

13 files changed

+204
-164
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
88
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
11+
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
1112
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

src/SuiteSparseGraphBLAS.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MacroTools
77
using LinearAlgebra
88
using Random: randsubseq, default_rng, AbstractRNG, GLOBAL_RNG
99
using CEnum
10-
10+
using ContextVariablesX
1111
include("abstracts.jl")
1212
include("libutils.jl")
1313
include("lib/LibGraphBLAS.jl")
@@ -57,6 +57,18 @@ const MonoidBinaryOrRig = Union{
5757
AbstractBinaryOp,
5858
AbstractMonoid
5959
}
60+
61+
const OperatorUnion = Union{
62+
AbstractOp,
63+
GrBOp
64+
}
65+
66+
@contextvar ctxop::OperatorUnion
67+
@contextvar ctxmask::Union{GBArray, Ptr} = C_NULL
68+
@contextvar ctxaccum::Union{BinaryUnion, Ptr} = C_NULL
69+
@contextvar ctxdesc::Descriptor
70+
include("with.jl")
71+
6072
include("scalar.jl")
6173
include("vector.jl")
6274
include("matrix.jl")
@@ -86,12 +98,13 @@ export GBScalar, GBVector, GBMatrix #arrays
8698
export clear!, extract, extract!, subassign!, assign! #array functions
8799

88100
#operations
89-
export mul, select, select!, eadd, eadd!, emul, emul!, apply, apply!, gbtranspose, gbtranspose!
90-
export multiply
101+
export mul, select, select!, eadd, eadd!, emul, emul!, map, map!, gbtranspose, gbtranspose!
91102
# Reexports.
92103
export diag, Diagonal, mul!, kron, kron!, transpose
93104
export nnz, sprand, findnz
94105

106+
#with function
107+
export with
95108

96109
function __init__()
97110
_load_globaltypes()
@@ -107,4 +120,8 @@ function __init__()
107120
libgb.GrB_finalize()
108121
end
109122
end
123+
124+
#We need to do this after __init__
125+
126+
110127
end #end of module

src/operations/ewise.jl

Lines changed: 52 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ union equivalent see [`eadd!`](@ref).
1919
- `mask::Union{Ptr{Nothing}, GBMatrix} = C_NULL`: optional mask.
2020
- `accum::Union{Ptr{Nothing}, AbstractBinaryOp} = C_NULL`: binary accumulator operation
2121
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
22-
- `desc::Descriptor = Descriptors.NULL`
22+
- `desc = nothing`
2323
"""
2424
function emul! end
2525

@@ -43,7 +43,7 @@ union equivalent see [`eadd`](@ref).
4343
- `mask::Union{Ptr{Nothing}, GBMatrix} = C_NULL`: optional mask.
4444
- `accum::Union{Ptr{Nothing}, AbstractBinaryOp} = C_NULL`: binary accumulator operation
4545
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
46-
- `desc::Descriptor = Descriptors.NULL`
46+
- `desc = nothing`
4747
4848
# Returns
4949
- `GBArray`: Output `GBVector` or `GBMatrix` whose eltype is determined by the `eltype` of
@@ -55,11 +55,13 @@ function emul!(
5555
w::GBVector,
5656
u::GBVector,
5757
v::GBVector,
58-
op::MonoidBinaryOrRig = BinaryOps.TIMES;
59-
mask = C_NULL,
60-
accum = C_NULL,
61-
desc::Descriptor = Descriptors.NULL
58+
op = nothing;
59+
mask = nothing,
60+
accum = nothing,
61+
desc = nothing
6262
)
63+
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
64+
6365
size(w) == size(u) == size(v) || throw(DimensionMismatch())
6466
op = getoperator(op, optype(u, v))
6567
accum = getoperator(accum, eltype(w))
@@ -72,17 +74,20 @@ function emul!(
7274
elseif op isa libgb.GrB_BinaryOp
7375
libgb.GrB_Vector_eWiseMult_BinaryOp(w, mask, accum, op, u, v, desc)
7476
return w
77+
else
78+
throw(ArgumentError("$op is not a valid monoid binary op or semiring."))
7579
end
7680
end
7781

7882
function emul(
7983
u::GBVector,
8084
v::GBVector,
81-
op::MonoidBinaryOrRig = BinaryOps.TIMES;
82-
mask = C_NULL,
83-
accum = C_NULL,
84-
desc::Descriptor = Descriptors.NULL
85+
op = nothing;
86+
mask = nothing,
87+
accum = nothing,
88+
desc = nothing
8589
)
90+
op = _handlectx(op, ctxop, BinaryOps.TIMES)
8691
if op isa GrBOp
8792
t = ztype(op)
8893
else
@@ -96,11 +101,12 @@ function emul!(
96101
C::GBMatrix,
97102
A::GBMatOrTranspose,
98103
B::GBMatOrTranspose,
99-
op::MonoidBinaryOrRig = BinaryOps.TIMES;
100-
mask = C_NULL,
101-
accum = C_NULL,
102-
desc::Descriptor = Descriptors.NULL
104+
op = nothing;
105+
mask = nothing,
106+
accum = nothing,
107+
desc = nothing
103108
)
109+
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
104110
size(C) == size(A) == size(B) || throw(DimensionMismatch())
105111
A, desc, B = _handletranspose(A, desc, B)
106112
op = getoperator(op, optype(A, B))
@@ -114,17 +120,20 @@ function emul!(
114120
elseif op isa libgb.GrB_BinaryOp
115121
libgb.GrB_Matrix_eWiseMult_BinaryOp(C, mask, accum, op, A, B, desc)
116122
return C
123+
else
124+
throw(ArgumentError("$op is not a valid monoid binary op or semiring."))
117125
end
118126
end
119127

120128
function emul(
121129
A::GBMatOrTranspose,
122130
B::GBMatOrTranspose,
123-
op::MonoidBinaryOrRig = BinaryOps.TIMES;
124-
mask = C_NULL,
125-
accum = C_NULL,
126-
desc::Descriptor = Descriptors.NULL
131+
op = nothing;
132+
mask = nothing,
133+
accum = nothing,
134+
desc = nothing
127135
)
136+
op = _handlectx(op, ctxop, BinaryOps.TIMES)
128137
if op isa GrBOp
129138
t = ztype(op)
130139
else
@@ -155,7 +164,7 @@ intersection equivalent see [`eadd!`](@ref).
155164
- `mask::Union{Ptr{Nothing}, GBMatrix} = C_NULL`: optional mask.
156165
- `accum::Union{Ptr{Nothing}, AbstractBinaryOp} = C_NULL`: binary accumulator operation
157166
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
158-
- `desc::Descriptor = Descriptors.NULL`
167+
- `desc = nothing`
159168
"""
160169
function eadd! end
161170

@@ -179,7 +188,7 @@ intersection equivalent see [`emul`](@ref).
179188
- `mask::Union{Ptr{Nothing}, GBMatrix} = C_NULL`: optional mask.
180189
- `accum::Union{Ptr{Nothing}, AbstractBinaryOp} = C_NULL`: binary accumulator operation
181190
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
182-
- `desc::Descriptor = Descriptors.NULL`
191+
- `desc = nothing`
183192
184193
# Returns
185194
- `GBArray`: Output `GBVector` or `GBMatrix` whose eltype is determined by the `eltype` of
@@ -191,11 +200,12 @@ function eadd!(
191200
w::GBVector,
192201
u::GBVector,
193202
v::GBVector,
194-
op::MonoidBinaryOrRig = BinaryOps.PLUS;
195-
mask = C_NULL,
196-
accum = C_NULL,
197-
desc::Descriptor = Descriptors.NULL
203+
op = nothing;
204+
mask = nothing,
205+
accum = nothing,
206+
desc = nothing
198207
)
208+
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
199209
size(w) == size(u) == size(v) || throw(DimensionMismatch())
200210
op = getoperator(op, optype(u, v))
201211
accum = getoperator(accum, eltype(w))
@@ -208,17 +218,20 @@ function eadd!(
208218
elseif op isa libgb.GrB_BinaryOp
209219
libgb.GrB_Vector_eWiseAdd_BinaryOp(w, mask, accum, op, u, v, desc)
210220
return w
221+
else
222+
throw(ArgumentError("$op is not a valid monoid binary op or semiring."))
211223
end
212224
end
213225

214226
function eadd(
215227
u::GBVector,
216228
v::GBVector,
217-
op::MonoidBinaryOrRig = BinaryOps.PLUS;
218-
mask = C_NULL,
219-
accum = C_NULL,
220-
desc::Descriptor = Descriptors.NULL
229+
op = nothing;
230+
mask = nothing,
231+
accum = nothing,
232+
desc = nothing
221233
)
234+
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
222235
if op isa GrBOp
223236
t = ztype(op)
224237
else
@@ -232,11 +245,12 @@ function eadd!(
232245
C::GBMatrix,
233246
A::GBMatOrTranspose,
234247
B::GBMatOrTranspose,
235-
op::MonoidBinaryOrRig = BinaryOps.PLUS;
236-
mask = C_NULL,
237-
accum = C_NULL,
238-
desc::Descriptor = Descriptors.NULL
248+
op = nothing;
249+
mask = nothing,
250+
accum = nothing,
251+
desc = nothing
239252
)
253+
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
240254
size(C) == size(A) == size(B) || throw(DimensionMismatch())
241255
A, desc, B = _handletranspose(A, desc, B)
242256
op = getoperator(op, optype(A, B))
@@ -251,18 +265,19 @@ function eadd!(
251265
libgb.GrB_Matrix_eWiseAdd_BinaryOp(C, mask, accum, op, A, B, desc)
252266
return C
253267
else
254-
error("Unreachable")
268+
throw(ArgumentError("$op is not a valid monoid binary op or semiring."))
255269
end
256270
end
257271

258272
function eadd(
259273
A::GBMatOrTranspose,
260274
B::GBMatOrTranspose,
261-
op::MonoidBinaryOrRig = BinaryOps.PLUS;
262-
mask = C_NULL,
263-
accum = C_NULL,
264-
desc::Descriptor = Descriptors.NULL
275+
op = nothing;
276+
mask = nothing,
277+
accum = nothing,
278+
desc = nothing
265279
)
280+
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
266281
if op isa GrBOp
267282
t = ztype(op)
268283
else
@@ -271,57 +286,3 @@ function eadd(
271286
C = GBMatrix{t}(size(A))
272287
return eadd!(C, A, B, op; mask, accum, desc)
273288
end
274-
275-
# Note well: `.*` and `.+` have clear counterparts in the language of GraphBLAS:
276-
# edgewiseAdd and edgewiseMul. These do not necessarily have the same semantics though.
277-
# edgewiseAdd and edgewiseMul might better be described as edgewiseUnion and
278-
# edgewiseIntersection respectively, and then `op` is applied at materialized indices.
279-
#
280-
# So the plan is thus: `.*` and `.+` will have the Union and Intersection semantics *with*
281-
# the default ops of `*` and `+` respectively. *However*, they have `op` kwargs, which
282-
# may be used with a macro later on down the line to override the default ops.
283-
function Base.broadcasted(
284-
::typeof(+),
285-
u::GBVector,
286-
v::GBVector,
287-
op::MonoidBinaryOrRig = BinaryOps.PLUS;
288-
mask = C_NULL,
289-
accum = C_NULL,
290-
desc::Descriptor = Descriptors.NULL
291-
)
292-
return eadd(u, v, op; mask, accum, desc)
293-
end
294-
function Base.broadcasted(
295-
::typeof(*),
296-
u::GBVector,
297-
v::GBVector,
298-
op::MonoidBinaryOrRig = BinaryOps.TIMES;
299-
mask = C_NULL,
300-
accum = C_NULL,
301-
desc::Descriptor = Descriptors.NULL
302-
)
303-
return emul(u, v, op; mask, accum, desc)
304-
end
305-
306-
function Base.broadcasted(
307-
::typeof(+),
308-
A::GBMatOrTranspose,
309-
B::GBMatOrTranspose,
310-
op::MonoidBinaryOrRig = BinaryOps.PLUS;
311-
mask = C_NULL,
312-
accum = C_NULL,
313-
desc::Descriptor = Descriptors.NULL
314-
)
315-
return eadd(A, B, op; mask, accum, desc)
316-
end
317-
function Base.broadcasted(
318-
::typeof(*),
319-
A::GBMatOrTranspose,
320-
B::GBMatOrTranspose,
321-
op::MonoidBinaryOrRig = BinaryOps.PLUS;
322-
mask = C_NULL,
323-
accum = C_NULL,
324-
desc::Descriptor = Descriptors.NULL
325-
)
326-
return emul(A, B, op; mask, accum, desc)
327-
end

src/operations/kronecker.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,40 @@ function LinearAlgebra.kron!(
22
C::GBMatOrTranspose,
33
A::GBMatOrTranspose,
44
B::GBMatOrTranspose,
5-
op::MonoidBinaryOrRig = BinaryOps.TIMES;
6-
mask = C_NULL,
7-
accum = C_NULL,
8-
desc::Descriptor = Descriptors.NULL
5+
op = nothing;
6+
mask = nothing,
7+
accum = nothing,
8+
desc = nothing
99
)
10+
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)
1011
op = getoperator(op, optype(A, B))
1112
A, desc, B = _handletranspose(A, desc, B)
1213
accum = getoperator(accum, eltype(C))
13-
if op isa Union{AbstractBinaryOp, libgb.GrB_BinaryOp}
14+
if op isa libgb.GrB_BinaryOp
1415
libgb.GxB_kron(C, mask, accum, op, A, B, desc)
15-
elseif op isa Union{AbstractMonoid, libgb.GrB_Monoid}
16+
elseif op isa libgb.GrB_Monoid
1617
libgb.GrB_Matrix_kronecker_Monoid(C, mask, accum, op, A, B, desc)
17-
elseif op isa Union{AbstractSemiring, libgb.GrB_Semiring}
18+
elseif op isa libgb.GrB_Semiring
1819
libgb.GrB_Matrix_kronecker_Semiring(C, mask, accum, op, A, B, desc)
20+
else
21+
throw(ArgumentError("$op is not a valid monoid binary op or semiring."))
1922
end
2023
end
2124

2225
function LinearAlgebra.kron(
2326
A::GBMatOrTranspose,
2427
B::GBMatOrTranspose,
25-
op::MonoidBinaryOrRig = BinaryOps.TIMES;
26-
mask = C_NULL,
27-
accum = C_NULL,
28-
desc::Descriptor = Descriptors.NULL
28+
op = nothing;
29+
mask = nothing,
30+
accum = nothing,
31+
desc = nothing
2932
)
30-
t = optype(A, B)
33+
op = _handlectx(op, ctxop, BinaryOps.TIMES)
34+
if op isa GrBOp
35+
t = ztype(op)
36+
else
37+
t = optype(A, B)
38+
end
3139
C = GBMatrix{t}(size(A,1) * size(B, 1), size(A, 2) * size(B, 2))
3240
kron!(C, A, B, op; mask, accum, desc)
3341
return C

0 commit comments

Comments
 (0)