Skip to content

Commit 5af0ee4

Browse files
add equality operator for matrices and vectors
1 parent 7bac8a0 commit 5af0ee4

File tree

4 files changed

+150
-8
lines changed

4 files changed

+150
-8
lines changed

src/Interface/Interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ module Interface
33
using SuiteSparseGraphBLAS
44

55
import Base:
6-
getindex, setindex!, empty!, copy, size
6+
getindex, setindex!, empty!, copy, size, ==
77

88
import SuiteSparseGraphBLAS:
99
GrB_Info, GrB_Index, GrB_Matrix, GrB_Vector, GrB_Descriptor, GrB_Desc_Field, GrB_Desc_Value,
10-
valid_types, get_GrB_Type, default_dup
10+
valid_types, get_GrB_Type, default_dup, equal_op
1111

1212
include("./Object_Methods/Matrix_Methods.jl")
1313
include("./Object_Methods/Vector_Methods.jl")

src/Interface/Object_Methods/Matrix_Methods.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,66 @@ function GrB_Matrix(T::DataType, nrows::GrB_Index, ncols::GrB_Index)
6868
return A
6969
end
7070

71+
"""
72+
==(A, B)
73+
74+
Check if two GraphBLAS matrices are equal.
75+
76+
# Examples
77+
```jldoctest
78+
julia> using SuiteSparseGraphBLAS
79+
80+
julia> GrB_init(GrB_NONBLOCKING)
81+
GrB_SUCCESS::GrB_Info = 0
82+
83+
julia> A = GrB_Matrix([1, 2, 3], [2, 4, 5], [1, 1, 1])
84+
GrB_Matrix{Int64}
85+
86+
julia> B = GrB_Matrix([1, 2, 3], [2, 4, 5], [1, 1, 1])
87+
GrB_Matrix{Int64}
88+
89+
julia> A == B
90+
true
91+
92+
julia> B = GrB_Matrix([1, 2, 3], [2, 4, 5], [1, 1, 2])
93+
GrB_Matrix{Int64}
94+
95+
julia> A == B
96+
false
97+
98+
julia> B = GrB_Matrix([1, 2, 3], [2, 4, 3], [1, 1, 1])
99+
GrB_Matrix{Int64}
100+
101+
julia> A == B
102+
false
103+
```
104+
"""
105+
function ==(A::GrB_Matrix{T}, B::GrB_Matrix{U}) where {T, U}
106+
T != U && return false
107+
108+
Asize = size(A)
109+
Anvals = nnz(A)
110+
111+
Asize == size(B) || return false
112+
Anvals == nnz(B) || return false
113+
114+
C = GrB_Matrix(Bool, Asize...)
115+
op = equal_op(T)
116+
117+
GrB_eWiseMult(C, GrB_NULL, GrB_NULL, op, A, B, GrB_NULL)
118+
119+
if nnz(C) != Anvals
120+
GrB_free(C)
121+
return false
122+
end
123+
124+
result = GrB_reduce(GrB_NULL, GxB_LAND_BOOL_MONOID, C, GrB_NULL)
125+
126+
GrB_free(C)
127+
128+
return result
129+
end
130+
71131
"""
72132
findnz(A)
73133

src/Interface/Object_Methods/Vector_Methods.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,54 @@ function GrB_Vector(T::DataType, n::GrB_Index)
7878
return V
7979
end
8080

81+
"""
82+
==(A, B)
83+
84+
Check if two GraphBLAS vectors are equal.
85+
86+
# Examples
87+
```jldoctest
88+
julia> using SuiteSparseGraphBLAS
89+
90+
julia> GrB_init(GrB_NONBLOCKING)
91+
GrB_SUCCESS::GrB_Info = 0
92+
93+
julia> A = GrB_Vector([1, 3, 4], [1, 1, 1])
94+
GrB_Vector{Int64}
95+
96+
julia> B = GrB_Vector([1, 3, 4], [1, 1, 1])
97+
GrB_Vector{Int64}
98+
99+
julia> A == B
100+
true
101+
```
102+
"""
103+
function ==(A::GrB_Vector{T}, B::GrB_Vector{U}) where {T, U}
104+
T != U && return false
105+
106+
Asize = size(A)
107+
Anvals = nnz(A)
108+
109+
Asize == size(B) || return false
110+
Anvals == nnz(B) || return false
111+
112+
C = GrB_Vector(Bool, Asize[1])
113+
op = equal_op(T)
114+
115+
GrB_eWiseMult(C, GrB_NULL, GrB_NULL, op, A, B, GrB_NULL)
116+
117+
if nnz(C) != Anvals
118+
GrB_free(C)
119+
return false
120+
end
121+
122+
result = GrB_reduce(GrB_NULL, GxB_LAND_BOOL_MONOID, C, GrB_NULL)
123+
124+
GrB_free(C)
125+
126+
return result
127+
end
128+
81129
"""
82130
size(V,[ dim])
83131

src/Utils.jl

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,47 @@ function default_dup(T::DataType)
7373
return GrB_FIRST_FP64
7474
end
7575

76-
function get_struct_name(A::GrB_Struct)
77-
s = ""
78-
for i in string(typeof(A))[5:end]
79-
i == '{' && break
80-
s *= i
76+
function equal_op(T::DataType)
77+
if T == Bool
78+
return GrB_EQ_BOOL
79+
elseif T == Int8
80+
return GrB_EQ_INT8
81+
elseif T == UInt8
82+
return GrB_EQ_UINT8
83+
elseif T == Int16
84+
return GrB_EQ_INT16
85+
elseif T == UInt16
86+
return GrB_EQ_UINT16
87+
elseif T == Int32
88+
return GrB_EQ_INT32
89+
elseif T == UInt32
90+
return GrB_EQ_UINT32
91+
elseif T == Int64
92+
return GrB_EQ_INT64
93+
elseif T == UInt64
94+
return GrB_EQ_UINT64
95+
elseif T == Float32
96+
return GrB_EQ_FP32
97+
end
98+
return GrB_EQ_FP64
99+
end
100+
101+
function get_struct_name(object::GrB_Struct)
102+
T = typeof(object)
103+
if T <: GrB_UnaryOp
104+
return "UnaryOp"
105+
elseif T <: GrB_BinaryOp
106+
return "BinaryOp"
107+
elseif T <: GrB_Monoid
108+
return "Monoid"
109+
elseif T <: GrB_Semiring
110+
return "Semiring"
111+
elseif T <: GrB_Vector
112+
return "Vector"
113+
elseif T <: GrB_Matrix
114+
return "Matrix"
81115
end
82-
return s
116+
return "Descriptor"
83117
end
84118

85119
function _GrB_Index(x::T) where T <: GrB_Index

0 commit comments

Comments
 (0)