-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathSparseMatrixColoringsCUDAExt.jl
More file actions
119 lines (102 loc) · 4.15 KB
/
SparseMatrixColoringsCUDAExt.jl
File metadata and controls
119 lines (102 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
module SparseMatrixColoringsCUDAExt

import SparseMatrixColorings as SMC
using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
using CUDA: CuVector, CuMatrix
using CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR

# A GPU sparse matrix is offered to SMC only as itself (a 1-tuple), presumably so that no
# extra converted copies are built on the device — confirm against `SMC.matrix_versions`.
SMC.matrix_versions(A::AbstractCuSparseMatrix) = (A,)

## Compression (slow, through CPU)

# Both methods below copy `A` to the host as a `SparseMatrixCSC`, compress it there with the
# generic CPU implementation, and upload the dense compressed matrix back as a `CuMatrix`.
# `:column` and `:row` partitions get separate methods rather than one method with a free
# partition parameter — presumably to stay unambiguous against SMC's own `compress`
# methods, which dispatch on the partition.

function SMC.compress(
    A::AbstractCuSparseMatrix, result::SMC.AbstractColoringResult{structure,:column}
) where {structure}
    return CuMatrix(SMC.compress(SparseMatrixCSC(A), result))
end

function SMC.compress(
    A::AbstractCuSparseMatrix, result::SMC.AbstractColoringResult{structure,:row}
) where {structure}
    return CuMatrix(SMC.compress(SparseMatrixCSC(A), result))
end

## CSC Result

# Each constructor computes on the CPU, for every stored entry of `A` in CSC order, its
# index into the compressed matrix `B`. A GPU copy of these indices is kept in
# `additional_info.compressed_indices_gpu_csc`, which is what `decompress!` reads.

function SMC.ColumnColoringResult(
    A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
) where {T<:Integer}
    group = SMC.group_by_color(T, color)
    compressed_indices = SMC.column_csc_indices(bg, color)
    additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
    return SMC.ColumnColoringResult(
        A, bg, color, group, compressed_indices, additional_info
    )
end

function SMC.RowColoringResult(
    A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
) where {T<:Integer}
    group = SMC.group_by_color(T, color)
    compressed_indices = SMC.row_csc_indices(bg, color)
    additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
    return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
end

function SMC.StarSetColoringResult(
    A::CuSparseMatrixCSC,
    ag::SMC.AdjacencyGraph{T},
    color::Vector{<:Integer},
    star_set::SMC.StarSet{<:Integer},
) where {T<:Integer}
    group = SMC.group_by_color(T, color)
    compressed_indices = SMC.star_csc_indices(ag, color, star_set)
    additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
    return SMC.StarSetColoringResult(
        A, ag, color, group, compressed_indices, additional_info
    )
end

## CSR Result

# The result still stores the CSC-order `compressed_indices`. The GPU copy, stored under
# `additional_info.compressed_indices_gpu_csr`, follows the CSR storage order of `A` so
# that `decompress!` can fill `A.nzVal` directly.

function SMC.ColumnColoringResult(
    A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
) where {T<:Integer}
    group = SMC.group_by_color(T, color)
    compressed_indices = SMC.column_csc_indices(bg, color)
    compressed_indices_csr = SMC.column_csr_indices(bg, color)
    additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices_csr))
    return SMC.ColumnColoringResult(
        A, bg, color, group, compressed_indices, additional_info
    )
end

function SMC.RowColoringResult(
    A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
) where {T<:Integer}
    group = SMC.group_by_color(T, color)
    compressed_indices = SMC.row_csc_indices(bg, color)
    compressed_indices_csr = SMC.row_csr_indices(bg, color)
    additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices_csr))
    return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
end

function SMC.StarSetColoringResult(
    A::CuSparseMatrixCSR,
    ag::SMC.AdjacencyGraph{T},
    color::Vector{<:Integer},
    star_set::SMC.StarSet{<:Integer},
) where {T<:Integer}
    group = SMC.group_by_color(T, color)
    compressed_indices = SMC.star_csc_indices(ag, color, star_set)
    # CSC-order indices are reused as the CSR-order ones. For a symmetric sparsity pattern
    # (with sorted indices), the k-th CSR entry of `A` sits at (i, j) exactly when the k-th
    # CSC entry sits at (j, i); with A[i, j] == A[j, i] the same index into `B` is valid.
    # NOTE(review): this relies on `A` being symmetric in values, as star coloring assumes.
    additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
    return SMC.StarSetColoringResult(
        A, ag, color, group, compressed_indices, additional_info
    )
end

## Decompression

# Fill the stored values of `A` from the compressed matrix `B`: the k-th stored value of
# `A` becomes `B[compressed_indices[k]]`, computed by a single `map!` over GPU arrays.
# Throws `DimensionMismatch` when `A` does not have as many stored entries as there are
# precomputed indices (so `A` cannot have the colored sparsity pattern). Without this
# check, `map!` on GPU arrays may only process the common length and silently leave the
# remaining entries of `A` stale. Returns `A`.
function _decompress_nonzeros!(A::AbstractCuSparseMatrix, B::CuMatrix, compressed_indices)
    nzA = A.nzVal
    if length(nzA) != length(compressed_indices)
        throw(
            DimensionMismatch(
                "`A` has $(length(nzA)) stored entries but the coloring result expects " *
                "$(length(compressed_indices)); `A` must have the same sparsity pattern " *
                "as the matrix that was colored",
            ),
        )
    end
    map!(Base.Fix1(getindex, B), nzA, compressed_indices)
    return A
end

for R in (:ColumnColoringResult, :RowColoringResult, :StarSetColoringResult)
    # One method per concrete result type and storage format (rather than a single
    # method on the abstract result type) to avoid method ambiguity with SMC's methods.
    @eval function SMC.decompress!(
        A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.$R{<:CuSparseMatrixCSC}
    )
        indices = result.additional_info.compressed_indices_gpu_csc
        return _decompress_nonzeros!(A, B, indices)
    end
    @eval function SMC.decompress!(
        A::CuSparseMatrixCSR, B::CuMatrix, result::SMC.$R{<:CuSparseMatrixCSR}
    )
        indices = result.additional_info.compressed_indices_gpu_csr
        return _decompress_nonzeros!(A, B, indices)
    end
end

end