Skip to content

Commit 12d538e

Browse files
author
Wimmerer
committed
op to positional, precomp struct generation
1 parent c1374c6 commit 12d538e

File tree

6 files changed

+63
-55
lines changed

6 files changed

+63
-55
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ version = "0.4.0"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
8+
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
9+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
10+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
811
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
912
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1013
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

src/SuiteSparseGraphBLAS.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ include("operators/monoids.jl")
2626
include("operators/semirings.jl")
2727
include("operators/selectops.jl")
2828

29+
_createunaryops()
30+
_createbinaryops()
31+
_createmonoids()
32+
_createsemirings()
33+
2934
include("descriptors.jl")
3035

3136
include("indexutils.jl")
@@ -74,6 +79,7 @@ include("export.jl")
7479
include("options.jl")
7580
#include("random.jl")
7681
include("misc.jl")
82+
include("chainrules/mulrules.jl")
7783
export libgb
7884
export UnaryOps, BinaryOps, Monoids, Semirings, SelectOps, Descriptors #Submodules
7985
export xtype, ytype, ztype
@@ -82,17 +88,13 @@ export clear!, extract, extract!, subassign!, assign! #array functions
8288

8389
#operations
8490
export mul, select, select!, eadd, eadd!, emul, emul!, apply, apply!, gbtranspose, gbtranspose!
85-
91+
export multiply
8692
# Reexports.
8793
export diag, Diagonal, mul!, kron, kron!, transpose
8894
export nnz, sprand, findnz
8995

9096

9197
function __init__()
92-
_createunaryops()
93-
_createbinaryops()
94-
_createmonoids()
95-
_createsemirings()
9698
_load_globaltypes()
9799
_loadselectops()
98100
_loaddescriptors()

src/operations/ewise.jl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ function emul end
5454
function emul!(
5555
w::GBVector,
5656
u::GBVector,
57-
v::GBVector;
58-
op::MonoidBinaryOrRig = BinaryOps.TIMES,
57+
v::GBVector,
58+
op::MonoidBinaryOrRig = BinaryOps.TIMES;
5959
mask = C_NULL,
6060
accum = C_NULL,
6161
desc::Descriptor = Descriptors.NULL
@@ -77,8 +77,8 @@ end
7777

7878
function emul(
7979
u::GBVector,
80-
v::GBVector;
81-
op::MonoidBinaryOrRig = BinaryOps.TIMES,
80+
v::GBVector,
81+
op::MonoidBinaryOrRig = BinaryOps.TIMES;
8282
mask = C_NULL,
8383
accum = C_NULL,
8484
desc::Descriptor = Descriptors.NULL
@@ -89,14 +89,14 @@ function emul(
8989
t = optype(u, v)
9090
end
9191
w = GBVector{t}(size(u))
92-
return emul!(w, u, v; op, mask , accum, desc)
92+
return emul!(w, u, v, op; mask , accum, desc)
9393
end
9494

9595
function emul!(
9696
C::GBMatrix,
9797
A::GBMatOrTranspose,
98-
B::GBMatOrTranspose;
99-
op::MonoidBinaryOrRig = BinaryOps.TIMES,
98+
B::GBMatOrTranspose,
99+
op::MonoidBinaryOrRig = BinaryOps.TIMES;
100100
mask = C_NULL,
101101
accum = C_NULL,
102102
desc::Descriptor = Descriptors.NULL
@@ -119,8 +119,8 @@ end
119119

120120
function emul(
121121
A::GBMatOrTranspose,
122-
B::GBMatOrTranspose;
123-
op::MonoidBinaryOrRig = BinaryOps.TIMES,
122+
B::GBMatOrTranspose,
123+
op::MonoidBinaryOrRig = BinaryOps.TIMES;
124124
mask = C_NULL,
125125
accum = C_NULL,
126126
desc::Descriptor = Descriptors.NULL
@@ -131,7 +131,7 @@ function emul(
131131
t = optype(A, B)
132132
end
133133
C = GBMatrix{t}(size(A))
134-
return emul!(C, A, B; op, mask, accum, desc)
134+
return emul!(C, A, B, op; mask, accum, desc)
135135
end
136136

137137
"""
@@ -190,8 +190,8 @@ function eadd end
190190
function eadd!(
191191
w::GBVector,
192192
u::GBVector,
193-
v::GBVector;
194-
op::MonoidBinaryOrRig = BinaryOps.PLUS,
193+
v::GBVector,
194+
op::MonoidBinaryOrRig = BinaryOps.PLUS;
195195
mask = C_NULL,
196196
accum = C_NULL,
197197
desc::Descriptor = Descriptors.NULL
@@ -213,8 +213,8 @@ end
213213

214214
function eadd(
215215
u::GBVector,
216-
v::GBVector;
217-
op::MonoidBinaryOrRig = BinaryOps.PLUS,
216+
v::GBVector,
217+
op::MonoidBinaryOrRig = BinaryOps.PLUS;
218218
mask = C_NULL,
219219
accum = C_NULL,
220220
desc::Descriptor = Descriptors.NULL
@@ -225,14 +225,14 @@ function eadd(
225225
t = optype(eltype(u), eltype(v))
226226
end
227227
w = GBVector{t}(size(u))
228-
return eadd!(w, u, v; op, mask, accum, desc)
228+
return eadd!(w, u, v, op; mask, accum, desc)
229229
end
230230

231231
function eadd!(
232232
C::GBMatrix,
233233
A::GBMatOrTranspose,
234-
B::GBMatOrTranspose;
235-
op::MonoidBinaryOrRig = BinaryOps.PLUS,
234+
B::GBMatOrTranspose,
235+
op::MonoidBinaryOrRig = BinaryOps.PLUS;
236236
mask = C_NULL,
237237
accum = C_NULL,
238238
desc::Descriptor = Descriptors.NULL
@@ -257,8 +257,8 @@ end
257257

258258
function eadd(
259259
A::GBMatOrTranspose,
260-
B::GBMatOrTranspose;
261-
op::MonoidBinaryOrRig = BinaryOps.PLUS,
260+
B::GBMatOrTranspose,
261+
op::MonoidBinaryOrRig = BinaryOps.PLUS;
262262
mask = C_NULL,
263263
accum = C_NULL,
264264
desc::Descriptor = Descriptors.NULL
@@ -269,7 +269,7 @@ function eadd(
269269
t = optype(A, B)
270270
end
271271
C = GBMatrix{t}(size(A))
272-
return eadd!(C, A, B; op, mask, accum, desc)
272+
return eadd!(C, A, B, op; mask, accum, desc)
273273
end
274274

275275
# Note well: `.*` and `.+` have clear counterparts in the language of GraphBLAS:
@@ -283,45 +283,45 @@ end
283283
function Base.broadcasted(
284284
::typeof(+),
285285
u::GBVector,
286-
v::GBVector;
287-
op::MonoidBinaryOrRig = BinaryOps.PLUS,
286+
v::GBVector,
287+
op::MonoidBinaryOrRig = BinaryOps.PLUS;
288288
mask = C_NULL,
289289
accum = C_NULL,
290290
desc::Descriptor = Descriptors.NULL
291291
)
292-
return eadd(u, v; op, mask, accum, desc)
292+
return eadd(u, v, op; mask, accum, desc)
293293
end
294294
function Base.broadcasted(
295295
::typeof(*),
296296
u::GBVector,
297-
v::GBVector;
298-
op::MonoidBinaryOrRig = BinaryOps.TIMES,
297+
v::GBVector,
298+
op::MonoidBinaryOrRig = BinaryOps.TIMES;
299299
mask = C_NULL,
300300
accum = C_NULL,
301301
desc::Descriptor = Descriptors.NULL
302302
)
303-
return emul(u, v; op, mask, accum, desc)
303+
return emul(u, v, op; mask, accum, desc)
304304
end
305305

306306
function Base.broadcasted(
307307
::typeof(+),
308308
A::GBMatOrTranspose,
309-
B::GBMatOrTranspose;
310-
op::MonoidBinaryOrRig = BinaryOps.PLUS,
309+
B::GBMatOrTranspose,
310+
op::MonoidBinaryOrRig = BinaryOps.PLUS;
311311
mask = C_NULL,
312312
accum = C_NULL,
313313
desc::Descriptor = Descriptors.NULL
314314
)
315-
return eadd(A, B; op, mask, accum, desc)
315+
return eadd(A, B, op; mask, accum, desc)
316316
end
317317
function Base.broadcasted(
318318
::typeof(*),
319319
A::GBMatOrTranspose,
320-
B::GBMatOrTranspose;
321-
op::MonoidBinaryOrRig = BinaryOps.PLUS,
320+
B::GBMatOrTranspose,
321+
op::MonoidBinaryOrRig = BinaryOps.PLUS;
322322
mask = C_NULL,
323323
accum = C_NULL,
324324
desc::Descriptor = Descriptors.NULL
325325
)
326-
return emul(A, B; op, mask, accum, desc)
326+
return emul(A, B, op; mask, accum, desc)
327327
end

src/operations/kronecker.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
function LinearAlgebra.kron!(
22
C::GBMatOrTranspose,
33
A::GBMatOrTranspose,
4-
B::GBMatOrTranspose;
5-
op::MonoidBinaryOrRig = BinaryOps.TIMES,
4+
B::GBMatOrTranspose,
5+
op::MonoidBinaryOrRig = BinaryOps.TIMES;
66
mask = C_NULL,
77
accum = C_NULL,
88
desc::Descriptor = Descriptors.NULL
@@ -21,14 +21,14 @@ end
2121

2222
function LinearAlgebra.kron(
2323
A::GBMatOrTranspose,
24-
B::GBMatOrTranspose;
25-
op::MonoidBinaryOrRig = BinaryOps.TIMES,
24+
B::GBMatOrTranspose,
25+
op::MonoidBinaryOrRig = BinaryOps.TIMES;
2626
mask = C_NULL,
2727
accum = C_NULL,
2828
desc::Descriptor = Descriptors.NULL
2929
)
3030
t = optype(A, B)
3131
C = GBMatrix{t}(size(A,1) * size(B, 1), size(A, 2) * size(B, 2))
32-
kron!(C, A, B; op, mask, accum, desc)
32+
kron!(C, A, B, op; mask, accum, desc)
3333
return C
3434
end

src/operations/map.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function Base.map!(
7373
op::BinaryUnion, A::GBArray, x;
7474
mask = C_NULL, accum = C_NULL, desc::Descriptor = Descriptors.NULL
7575
)
76-
return map!(op,A, A, x; mask, accum, desc)
76+
return map!(op, A, A, x; mask, accum, desc)
7777
end
7878

7979
function Base.map(

src/operations/mul.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
function LinearAlgebra.mul!(
22
C::GBMatrix,
33
A::GBMatOrTranspose,
4-
B::GBMatOrTranspose;
5-
op::SemiringUnion = Semirings.PLUS_TIMES,
4+
B::GBMatOrTranspose,
5+
op::SemiringUnion = Semirings.PLUS_TIMES;
66
mask = C_NULL,
77
accum = C_NULL,
88
desc::Descriptor = Descriptors.NULL
@@ -20,8 +20,8 @@ end
2020
function LinearAlgebra.mul!(
2121
w::GBVector,
2222
u::GBVector,
23-
A::GBMatOrTranspose;
24-
op::SemiringUnion = Semirings.PLUS_TIMES,
23+
A::GBMatOrTranspose,
24+
op::SemiringUnion = Semirings.PLUS_TIMES;
2525
mask = C_NULl,
2626
accum = C_NULL,
2727
desc::Descriptor = Descriptors.NULL
@@ -38,8 +38,8 @@ end
3838
function LinearAlgebra.mul!(
3939
w::GBVector,
4040
A::GBMatOrTranspose,
41-
u::GBVector;
42-
op::SemiringUnion = Semirings.PLUS_TIMES,
41+
u::GBVector,
42+
op::SemiringUnion = Semirings.PLUS_TIMES;
4343
mask = C_NULL,
4444
accum = C_NULL,
4545
desc::Descriptor = Descriptors.NULL
@@ -78,10 +78,13 @@ The default semiring is the `+.*` semiring.
7878
- `GBArray`: The output matrix whose `eltype` is determined by `A` and `B` or the semiring
7979
if a type specific semiring is provided.
8080
"""
81+
function multiply(A::GBArray, B::GBArray)
82+
return mul(A, B, Semirings.PLUS_TIMES)
83+
end
8184
function mul(
8285
A::GBArray,
83-
B::GBArray;
84-
op::SemiringUnion = Semirings.PLUS_TIMES,
86+
B::GBArray,
87+
op::SemiringUnion = Semirings.PLUS_TIMES;
8588
mask = C_NULL,
8689
accum = C_NULL,
8790
desc::Descriptor = Descriptors.NULL
@@ -96,17 +99,17 @@ function mul(
9699
else
97100
throw(ArgumentError("Cannot multiply A::GBVector, B::GBVector. Try emul"))
98101
end
99-
mul!(C, A, B; op, mask, accum, desc)
102+
mul!(C, A, B, op; mask, accum, desc)
100103
return C
101104
end
102105

103106
function Base.:*(
104107
A::GBArray,
105-
B::GBArray;
106-
op::SemiringUnion = Semirings.PLUS_TIMES,
108+
B::GBArray,
109+
op::SemiringUnion = Semirings.PLUS_TIMES;
107110
mask = C_NULL,
108111
accum = C_NULL,
109112
desc::Descriptor = Descriptors.NULL
110113
)
111-
mul(A, B; op, mask, accum, desc)
114+
mul(A, B, op; mask, accum, desc)
112115
end

0 commit comments

Comments
 (0)