Skip to content

Commit 16819ef

Browse files
committed
descriptor rework, add reduce error for masking. TODO: determine shape of reduction result
1 parent 2e67430 commit 16819ef

File tree

5 files changed

+65
-81
lines changed

5 files changed

+65
-81
lines changed

src/descriptors.jl

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -39,69 +39,49 @@ Base.unsafe_convert(::Type{LibGraphBLAS.GrB_Descriptor}, d::Descriptor) = d.p
3939
Base.:+(a::LibGraphBLAS.GrB_Desc_Value, b::LibGraphBLAS.GrB_Desc_Value) =
4040
Integer(a) + Integer(b)
4141

42+
function symtodescfield(sym::Symbol)
43+
sym === :replace_output && return LibGraphBLAS.GrB_OUTP
44+
(sym === :complement_mask || sym === :structural_mask) &&
45+
return LibGraphBLAS.GrB_MASK
46+
sym === :transpose_input1 && return LibGraphBLAS.GrB_INP0
47+
sym === :transpose_input2 && return LibGraphBLAS.GrB_INP1
48+
sym === :nthreads && return LibGraphBLAS.GxB_DESCRIPTOR_NTHREADS
49+
sym === :chunk && return LibGraphBLAS.GxB_DESCRIPTOR_CHUNK
50+
sym === :sort && return LibGraphBLAS.GxB_SORT
51+
throw(ArgumentError("$sym is not a valid Descriptor field"))
52+
end
53+
function descfieldtype(sym::Symbol)
54+
if sym [:replace_output, :transpose_input1, :transpose_input2]
55+
return LibGraphBLAS.GrB_Desc_Value
56+
elseif sym [:chunk]
57+
return Float64
58+
else
59+
return Int32
60+
end
61+
end
62+
function Desc_get(d::Descriptor, field::Symbol)
63+
o = Ref{descfieldtype(field)}()
64+
@wraperror LibGraphBLAS.GxB_Desc_get(d, symtodescfield(field), o)
65+
return o[]
66+
end
67+
4268
function Base.getproperty(d::Descriptor, s::Symbol)
4369
if s === :p
4470
return getfield(d, s)
45-
elseif s === :replace_output
46-
x = LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GrB_OUTP)
47-
if x == LibGraphBLAS.GrB_REPLACE
48-
return true
49-
else
50-
return false
51-
end
71+
end
72+
x = Desc_get(d, s)
73+
if s === :replace_output
74+
return x == LibGraphBLAS.GrB_REPLACE
5275
elseif s === :complement_mask
53-
x = LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GrB_MASK)
54-
if x == LibGraphBLAS.GrB_COMP || x == (LibGraphBLAS.GrB_STRUCTURE + LibGraphBLAS.GrB_COMP)
55-
return true
56-
else
57-
return false
58-
end
76+
return x == Integer(LibGraphBLAS.GrB_COMP) || x ==
77+
(LibGraphBLAS.GrB_STRUCTURE + LibGraphBLAS.GrB_COMP)
5978
elseif s === :structural_mask
60-
x = LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GrB_MASK)
61-
if x == LibGraphBLAS.GrB_STRUCTURE || x == (LibGraphBLAS.GrB_STRUCTURE + LibGraphBLAS.GrB_COMP)
62-
return true
63-
else
64-
return false
65-
end
66-
elseif s === :transpose_input1
67-
x = LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GrB_INP0)
68-
if x == LibGraphBLAS.GrB_TRAN
69-
return true
70-
else
71-
return false
72-
end
73-
elseif s === :transpose_input2
74-
x = LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GrB_INP1)
75-
if x == LibGraphBLAS.GrB_TRAN
76-
return true
77-
else
78-
return false
79-
end
80-
elseif s === :nthreads
81-
return LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GxB_DESCRIPTOR_NTHREADS)
82-
elseif s === :chunk
83-
return LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GxB_DESCRIPTOR_CHUNK)
84-
elseif s === :sort
85-
if LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GxB_SORT) == LibGraphBLAS.GxB_DEFAULT
86-
return false
87-
else
88-
return true
89-
end
90-
elseif s === :axb_method
91-
x = LibGraphBLAS.GxB_Desc_get(d, LibGraphBLAS.GxB_AxB_METHOD)
92-
if x == LibGraphBLAS.GxB_AxB_GUSTAVSON
93-
return :gustavson
94-
elseif x == LibGraphBLAS.GxB_AxB_DOT
95-
return :dot
96-
elseif x == LibGraphBLAS.AxB_HASH
97-
return :hash
98-
elseif x == LibGraphBLAS.GxB_AxB_SAXPY
99-
return :saxpy
100-
else
101-
return :default
102-
end
103-
else
104-
return getfield(d, s)
79+
return x == Integer(LibGraphBLAS.GrB_STRUCTURE) || x ==
80+
(LibGraphBLAS.GrB_STRUCTURE + LibGraphBLAS.GrB_COMP)
81+
elseif s === :transpose_input1 || s === :transpose_input2
82+
return x == LibGraphBLAS.GrB_TRAN
83+
elseif s === :nthreads || s === :chunk || s === :sort
84+
return x
10585
end
10686
end
10787

src/lib/LibGraphBLAS_gen.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LibGraphBLAS
22
import ..libgraphblas
33
to_c_type(t::Type) = t
4+
to_c_type(::Type{Base.RefValue{T}}) where T = Base.Ptr{T}
45
to_c_type_pairs(va_list) = map(enumerate(to_c_type.(va_list))) do (ind, type)
56
:(va_list[$ind]::$type)
67
end

src/operations/reduce.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ function Base.reduce(
3939
reduce!(op, w, A; desc, accum, mask)
4040
return w
4141
elseif dims == (1,2) || dims == Colon() || A isa GBVectorOrTranspose
42+
mask !== nothing && throw(
43+
ArgumentError("Reduction to a scalar does not support masking."))
4244
if init === nothing
4345
c = GBScalar{typeout}()
4446
else

test/operations/reduce.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testset "reduce" begin
2-
# m = GBMatrix([[1,2,3] [4,5,6] [7,8,9]])
3-
# @test reduce(max, m, dims=2) == reduce(max, m) #this only works for dense
4-
# @test reduce(max, m, dims=(1,2)) == 9
5-
# @test_throws ArgumentError reduce(*, m) ?? I don't recognize this test. And it doesn't pass in older versions?
2+
@testset "Reduction of Vec -> Scalar" begin
3+
v = GBVector(1:10)
4+
@test reduce(+, v) == 55
5+
@test reduce(*, v) == 3628800
6+
end
67
end

test/runtests.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -70,25 +70,25 @@ SuiteSparseGraphBLAS.@unop foo F=>F
7070
println("Testing SuiteSparseGraphBLAS.jl")
7171
println("$(SuiteSparseGraphBLAS.get_lib())")
7272
@testset "SuiteSparseGraphBLAS" begin
73-
include_test("libutils.jl")
74-
include_test("operatorutils.jl")
75-
include_test("ops.jl")
76-
include_test("gbarray.jl")
77-
include_test("issues.jl")
78-
include_test("operations/ewise.jl")
79-
include_test("operations/kron.jl")
80-
include_test("operations/map.jl")
81-
include_test("operations/mul.jl")
73+
# include_test("libutils.jl")
74+
# include_test("operatorutils.jl")
75+
# include_test("ops.jl")
76+
# include_test("gbarray.jl")
77+
# include_test("issues.jl")
78+
# include_test("operations/ewise.jl")
79+
# include_test("operations/kron.jl")
80+
# include_test("operations/map.jl")
81+
# include_test("operations/mul.jl")
8282
include_test("operations/reduce.jl")
83-
include_test("operations/select.jl")
84-
include_test("operations/transpose.jl")
85-
include_test("operations/broadcasting.jl")
86-
include_test("operations/concat.jl")
87-
include_test("chainrules/chainrulesutils.jl")
88-
include_test("chainrules/mulrules.jl")
89-
include_test("chainrules/ewiserules.jl")
90-
include_test("chainrules/selectrules.jl")
91-
include_test("chainrules/constructorrules.jl")
92-
include_test("chainrules/maprules.jl")
83+
# include_test("operations/select.jl")
84+
# include_test("operations/transpose.jl")
85+
# include_test("operations/broadcasting.jl")
86+
# include_test("operations/concat.jl")
87+
# include_test("chainrules/chainrulesutils.jl")
88+
# include_test("chainrules/mulrules.jl")
89+
# include_test("chainrules/ewiserules.jl")
90+
# include_test("chainrules/selectrules.jl")
91+
# include_test("chainrules/constructorrules.jl")
92+
# include_test("chainrules/maprules.jl")
9393

9494
end

0 commit comments

Comments
 (0)