Skip to content

Commit 1d4eeb9

Browse files
committed
bcasting tests, correct isiso usage
1 parent f72a06b commit 1d4eeb9

File tree

6 files changed

+67
-15
lines changed

6 files changed

+67
-15
lines changed

src/abstractgbarray.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,16 @@ function Base.setindex!(
681681
)
682682
subassign!(C, A, I, J; mask, accum, desc)
683683
end
684+
function Base.setindex!(
685+
C::AbstractGBMatrix,
686+
A,
687+
::Colon;
688+
mask = nothing,
689+
accum = nothing,
690+
desc = nothing
691+
)
692+
subassign!(C, A, :, :; mask, accum, desc)
693+
end
684694

685695
# AbstractGBVector functions:
686696
#############################

src/operations/broadcasts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ Base.Broadcast.broadcasted(::Type{T}, A::AbstractGBArray) where T = LinearAlgebr
356356
# This is overly verbose, perhaps a macro?
357357
# return an operator that swaps the order of the operands.
358358
# * -> *, first -> second, second -> first, - -> rminus, etc.
359-
_swapop(op) = throw(ArgumentError("Cannot swap order of operands automatically. Swap the order of the broadcast statement or overload `_swapop`"))
359+
_swapop(op) = nothing
360360
_swapop(::typeof(first)) = second
361361
_swapop(::typeof(second)) = first
362362

src/operations/ewise.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ function _bcastemul!(
157157
)
158158
op2 = _swapop(op)
159159
if op2 === nothing # manually bcast:
160-
full = similar(B)
160+
full = similar(B, size(A, 2))
161161
full[:] = 0
162162
T = *(B, full', (any, first); mask)
163163
return emul!(C, A, T, op; mask, accum, desc)
@@ -186,9 +186,9 @@ function _bcastemul!(
186186
)
187187
op2 = _swapop(op)
188188
if op2 === nothing
189-
full = similar(A)
189+
full = similar(A, size(B, 1))
190190
full[:] = 0
191-
T = *(A, full', (any, first); mask)
191+
T = *(full, A, (any, second); mask)
192192
return emul!(C, T, B, op; mask, accum, desc)
193193
end
194194
return mul!(C, B, Diagonal(parent(A)), (any, op2); mask, accum, desc)

src/operations/mul.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ function Base.:*(
9595
fill = _promotefill(parent(A), parent(B), op)
9696
if A isa GBMatrixOrTranspose && B isa AbstractGBVector
9797
C = similar(A, T, size(A, 1); fill)
98-
elseif A isa AbstractGBVector && B isa GBMatrixOrTranspose
99-
C = similar(A, T, size(B, 2); fill)
10098
elseif A isa Transpose{<:Any, <:AbstractGBVector} && B isa AbstractGBVector
10199
C = similar(A, T, 1; fill)
102100
else

src/unpack.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function _unpackdensematrix!(
66
desc = _handledescriptor(desc)
77
Csize = Ref{LibGraphBLAS.GrB_Index}(length(A) * sizeof(T))
88
values = Ref{Ptr{Cvoid}}(C_NULL)
9-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
9+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
1010
@wraperror LibGraphBLAS.GxB_Matrix_unpack_FullC(
1111
A,
1212
values,
@@ -32,7 +32,7 @@ function _unpackdensematrix!(
3232
desc = _handledescriptor(desc)
3333
Csize = Ref{LibGraphBLAS.GrB_Index}(length(A) * sizeof(T))
3434
values = Ref{Ptr{Cvoid}}(C_NULL)
35-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
35+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
3636
@wraperror LibGraphBLAS.GxB_Matrix_unpack_FullC(
3737
A,
3838
values,
@@ -60,7 +60,7 @@ function _unpackdensematrixR!(
6060
desc = _handledescriptor(desc)
6161
Csize = Ref{LibGraphBLAS.GrB_Index}(length(A) * sizeof(T))
6262
values = Ref{Ptr{Cvoid}}(C_NULL)
63-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
63+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
6464
@wraperror LibGraphBLAS.GxB_Matrix_unpack_FullR(
6565
A,
6666
values,
@@ -91,7 +91,7 @@ function _unpackcscmatrix!(
9191
colptrsize = Ref{LibGraphBLAS.GrB_Index}()
9292
rowidxsize = Ref{LibGraphBLAS.GrB_Index}()
9393
valsize = Ref{LibGraphBLAS.GrB_Index}()
94-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
94+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
9595
isjumbled = C_NULL
9696
nnonzeros = nnz(A)
9797
@wraperror LibGraphBLAS.GxB_Matrix_unpack_CSC(
@@ -106,6 +106,7 @@ function _unpackcscmatrix!(
106106
isjumbled,
107107
desc
108108
)
109+
isiso == C_NULL && (isiso = false)
109110
colptr = unsafe_wrap(Array, Ptr{Int64}(colptr[]), size(A, 2) + 1)
110111
rowidx = unsafe_wrap(Array, Ptr{Int64}(rowidx[]), nnonzeros)
111112
nstored = isiso[] ? 1 : nnonzeros
@@ -141,7 +142,7 @@ function _unpackcsrmatrix!(
141142
rowptrsize = Ref{LibGraphBLAS.GrB_Index}()
142143
colidxsize = Ref{LibGraphBLAS.GrB_Index}()
143144
valsize = Ref{LibGraphBLAS.GrB_Index}()
144-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
145+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
145146
isjumbled = C_NULL
146147
nnonzeros = nnz(A)
147148
@wraperror LibGraphBLAS.GxB_Matrix_unpack_CSR(
@@ -156,6 +157,7 @@ function _unpackcsrmatrix!(
156157
isjumbled,
157158
desc
158159
)
160+
isiso == C_NULL && (isiso = false)
159161
rowptr = unsafe_wrap(Array, Ptr{Int64}(rowptr[]), size(A, 1) + 1)
160162
colidx = unsafe_wrap(Array, Ptr{Int64}(colidx[]), nnonzeros)
161163
nstored = isiso[] ? 1 : nnonzeros
@@ -190,7 +192,7 @@ function _unpackbitmapmatrix!(
190192
Bsize = Ref{LibGraphBLAS.GrB_Index}(length(A) * sizeof(Bool))
191193
values = Ref{Ptr{Cvoid}}(C_NULL)
192194
bytemap = Ref{Ptr{Int8}}(C_NULL)
193-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
195+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
194196
nnonzeros = Ref{LibGraphBLAS.GrB_Index}(nnz(A))
195197
@wraperror LibGraphBLAS.GxB_Matrix_unpack_BitmapC(
196198
A,
@@ -202,6 +204,7 @@ function _unpackbitmapmatrix!(
202204
nnonzeros,
203205
desc
204206
)
207+
isiso == C_NULL && (isiso = false)
205208
nstored = isiso[] ? 1 : szA
206209
v = unsafe_wrap(Array, Ptr{T}(values[]), nstored)
207210
b = unsafe_wrap(Array, Ptr{Bool}(bytemap[]), szA)
@@ -226,7 +229,7 @@ function _unpackbitmapmatrixR!(
226229
Bsize = Ref{LibGraphBLAS.GrB_Index}(length(A) * sizeof(Int8))
227230
values = Ref{Ptr{Cvoid}}(C_NULL)
228231
bytemap = Ref{Ptr{Int8}}(C_NULL)
229-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
232+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
230233
nonzeros = Ref{LibGraphBLAS.GrB_Index}(0)
231234
@wraperror LibGraphBLAS.GxB_Matrix_unpack_BitmapR(
232235
A,
@@ -238,6 +241,7 @@ function _unpackbitmapmatrixR!(
238241
nonzeros,
239242
desc
240243
)
244+
isiso == C_NULL && (isiso = false)
241245
nstored = isiso[] ? 1 : szA
242246
v = unsafe_wrap(Array, Ptr{T}(values[]), nstored)
243247
b = unsafe_wrap(Array, bytemap[], szA)
@@ -266,7 +270,7 @@ function _unpackhypermatrix!(
266270
rowidxsize = Ref{LibGraphBLAS.GrB_Index}()
267271
valsize = Ref{LibGraphBLAS.GrB_Index}()
268272
nvec = Ref{LibGraphBLAS.GrB_Index}()
269-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
273+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
270274
isjumbled = C_NULL
271275
nnonzeros = nnz(A)
272276

@@ -285,6 +289,7 @@ function _unpackhypermatrix!(
285289
isjumbled,
286290
desc
287291
)
292+
isiso == C_NULL && (isiso = false)
288293
nvec = nvec[]
289294
colptr = unsafe_wrap(Array, Ptr{Int64}(colptr[]), nvec + 1)
290295
colidx = unsafe_wrap(Array, Ptr{Int64}(colidx), nvec)
@@ -329,7 +334,7 @@ desc = _handledescriptor(desc)
329334
colidxsize = Ref{LibGraphBLAS.GrB_Index}()
330335
valsize = Ref{LibGraphBLAS.GrB_Index}()
331336
nvec = Ref{LibGraphBLAS.GrB_Index}()
332-
isiso = Ref{Bool}(allowiso ? true : C_NULL)
337+
isiso = allowiso ? Ref{Bool}(true) : C_NULL
333338
isjumbled = C_NULL
334339
nnonzeros = nnz(A)
335340

@@ -348,6 +353,7 @@ desc = _handledescriptor(desc)
348353
isjumbled,
349354
desc
350355
)
356+
isiso == C_NULL && (isiso = false)
351357
nvec = nvec[]
352358
rowptr = unsafe_wrap(Array, Ptr{Int64}(rowptr[]), nvec + 1)
353359
rowidx = unsafe_wrap(Array, Ptr{Int64}(rowidx[]), nvec)

test/operations/broadcasting.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,43 @@
1818
u = rand(1000)
1919
v = GBVector(u)
2020
@test sin.(u) Vector(sin.(v))
21+
@testset "Dimensional Broadcasting" begin
22+
A = rand(3,5)
23+
u = rand(3)
24+
v = rand(5)
25+
26+
G = GBMatrix(A)
27+
uG = GBVector(u)
28+
vG = GBVector(v)
29+
30+
@test A .* u G .* uG u .* A uG .* G
31+
@test_throws DimensionMismatch G .* vG
32+
33+
@test A .* v' G .* vG' v' .* A vG' .* G
34+
@test_throws DimensionMismatch uG' .* G
35+
@test_throws DimensionMismatch G .* uG'
36+
37+
@test u .* u uG .* uG
38+
@test_throws DimensionMismatch uG .* vG
39+
@test u' .* u' uG' .* uG'
40+
@test_throws DimensionMismatch uG' .* vG'
41+
42+
@test u .* v' uG .* vG'
43+
@test v' .* u vG' .* uG
44+
@test v .* u' vG .* uG'
45+
@test u' .* v uG' .* vG
46+
47+
# tests without a _swapop
48+
@test A .^ u G .^ uG
49+
@test u .^ A uG .^ G
50+
@test A .^ v' G .^ vG'
51+
@test v' .^ A vG' .^ G
52+
@test u' .^ u' uG' .^ uG'
53+
54+
@test u .^ v' uG .^ vG'
55+
@test v' .^ u vG' .^ uG
56+
@test v .^ u' vG .^ uG'
57+
@test u' .^ v uG' .^ vG
58+
end
2159
end
2260

0 commit comments

Comments
 (0)