Skip to content

Commit bf4b8c4

Browse files
make GrB_Type parametric
1 parent be4561d commit bf4b8c4

File tree

5 files changed

+21
-52
lines changed

5 files changed

+21
-52
lines changed

src/Object_Methods/Matrix_Methods.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@ julia> GrB_Matrix_new(MAT, GrB_INT8, 4, 4)
1717
GrB_SUCCESS::GrB_Info = 0
1818
```
1919
"""
20-
function GrB_Matrix_new(A::GrB_Matrix{T}, type::GrB_Type, nrows::U, ncols::U) where {U <: GrB_Index, T <: valid_types}
20+
function GrB_Matrix_new(A::GrB_Matrix{T}, type::GrB_Type{T}, nrows::U, ncols::U) where {U <: GrB_Index, T <: valid_types}
2121
A_ptr = pointer_from_objref(A)
22-
if jl_type(type) != T
23-
error("Domain and matrix type do not match")
24-
end
22+
2523
return GrB_Info(
2624
ccall(
2725
dlsym(graphblas_lib, "GrB_Matrix_new"),

src/Object_Methods/Vector_Methods.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@ julia> GrB_Vector_new(V, GrB_FP64, 4)
1717
GrB_SUCCESS::GrB_Info = 0
1818
```
1919
"""
20-
function GrB_Vector_new(v::GrB_Vector{T}, type::GrB_Type, n::U) where {U <: GrB_Index, T <: valid_types}
20+
function GrB_Vector_new(v::GrB_Vector{T}, type::GrB_Type{T}, n::U) where {U <: GrB_Index, T <: valid_types}
2121
v_ptr = pointer_from_objref(v)
22-
if jl_type(type) != T
23-
error("Domain and vector type do not match")
24-
end
22+
2523
return GrB_Info(
2624
ccall(
2725
dlsym(graphblas_lib, "GrB_Vector_new"),

src/Structures.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@ import Base.show
22
import Base.==
33
export GrB_Type, GrB_UnaryOp, GrB_BinaryOp, GrB_Vector, GrB_Matrix, GrB_Descriptor
44

5-
mutable struct GrB_Type
5+
mutable struct GrB_Type{T}
66
p::Ptr{Cvoid}
77
end
8-
GrB_Type() = GrB_Type(Ptr{Cvoid}(0))
9-
Base.show(io::IO, ::GrB_Type) = print("GrB_Type")
10-
function ==(t1::GrB_Type, t2::GrB_Type)
11-
t1.p == t2.p
12-
end
8+
GrB_Type{T}() where T = GrB_Type{T}(Ptr{Cvoid}(0))
9+
Base.show(io::IO, ::GrB_Type{T}) where T = print("GrB_Type{" * string(T) * "}")
1310

1411
mutable struct GrB_UnaryOp
1512
p::Ptr{Cvoid}
@@ -39,4 +36,4 @@ mutable struct GrB_Descriptor
3936
p::Ptr{Cvoid}
4037
end
4138
GrB_Descriptor() = GrB_Descriptor(Ptr{Cvoid}(0))
42-
Base.show(io::IO, ::GrB_Descriptor) = print("GrB_Descriptor")
39+
Base.show(io::IO, ::GrB_Descriptor) = print("GrB_Descriptor")

src/SuiteSparseGraphBLAS.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,22 @@ if !isfile(depsjl_path)
77
error("SuiteSparseGraphBLAS not installed properly, run Pkg.build(\"SuiteSparseGraphBLAS\"), restart Julia and try again")
88
end
99

10-
include(depsjl_path)
11-
include("Structures.jl")
12-
13-
types = ["BOOL", "INT8", "UINT8", "INT16", "UINT16", "INT32", "UINT32",
14-
"INT64", "UINT64", "FP32", "FP64"]
10+
const types = ["BOOL", "INT8", "UINT8", "INT16", "UINT16", "INT32", "UINT32",
11+
"INT64", "UINT64", "FP32", "FP64"]
1512

16-
GrB_Index = Union{Int64, UInt64}
17-
valid_types = Union{Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float32, Float64}
18-
valid_int_types = Union{Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}
13+
const GrB_Index = Union{Int64, UInt64}
14+
const valid_types = Union{Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float32, Float64}
15+
const valid_int_types = Union{Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}
1916

2017
unary_operators = ["IDENTITY", "AINV", "MINV"]
2118

2219
binary_operators = ["EQ", "NE", "GT", "LT", "GE", "LE", "FIRST", "SECOND", "MIN", "MAX",
2320
"PLUS", "MINUS", "TIMES", "DIV"]
2421

22+
include(depsjl_path)
23+
include("Structures.jl")
24+
include("Utils.jl")
25+
2526
const GrB_LNOT = GrB_UnaryOp()
2627
const GrB_LOR = GrB_BinaryOp(); const GrB_LAND = GrB_BinaryOp(); const GrB_LXOR = GrB_BinaryOp()
2728
graphblas_lib = C_NULL
@@ -39,9 +40,10 @@ function __init__()
3940
end
4041

4142
#load global types
42-
for t in types
43-
x = GrB_Type(load_global("GrB_"*t))
44-
@eval const $(Symbol(:GrB_, t)) = $x
43+
for t in [Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float32, Float64]
44+
type_suffix = suffix(t)
45+
x = GrB_Type{t}(load_global("GrB_"*type_suffix))
46+
@eval const $(Symbol(:GrB_, type_suffix)) = $x
4547
end
4648

4749
#load global unary operators
@@ -69,7 +71,6 @@ end
6971

7072
include("Enums.jl")
7173
include("Context_Methods.jl")
72-
include("Utils.jl")
7374
include("Object_Methods/Matrix_Methods.jl")
7475
include("Object_Methods/Vector_Methods.jl")
7576
include("Object_Methods/Descriptor_Methods.jl")

src/Utils.jl

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,6 @@ function suffix(T::DataType)
2323
return "FP64"
2424
end
2525

26-
function jl_type(T::GrB_Type)
27-
if T == GrB_BOOL
28-
return Bool
29-
elseif T == GrB_INT8
30-
return Int8
31-
elseif T == GrB_UINT8
32-
return UInt8
33-
elseif T == GrB_INT16
34-
return Int16
35-
elseif T == GrB_UINT16
36-
return UInt16
37-
elseif T == GrB_INT32
38-
return Int32
39-
elseif T == GrB_UINT32
40-
return UInt32
41-
elseif T == GrB_INT64
42-
return Int64
43-
elseif T == GrB_UINT64
44-
return UInt64
45-
elseif T == GrB_FP32
46-
return Float32
47-
end
48-
return Float64
49-
end
50-
5126
function _GrB_Index(x::T) where T <: GrB_Index
5227
x > typemax(Int64) && return x
5328
return Int64(x)

0 commit comments

Comments
 (0)