@@ -137,11 +137,39 @@ function to_sparse(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
137
137
s, t, eweight = coo
138
138
eweight = isnothing (eweight) ? fill! (similar (s, T), 1 ) : eweight
139
139
num_nodes = isnothing (num_nodes) ? max (maximum (s), maximum (t)) : num_nodes
140
- A = sparse (s, t, eweight, num_nodes, num_nodes)
140
+ A = _sparse (s, t, eweight, num_nodes, num_nodes)
141
141
num_edges = length (s)
142
142
return A, num_nodes, num_edges
143
143
end
144
144
145
+ _sparse (s, t, eweight, n, m) = sparse (s, t, eweight, n, m)
146
+
147
+ function _sparse (I:: CuVector , J:: CuVector , V:: CuVector , m, n)
148
+ spcoo = CuSparseMatrixCOO {Float32, Int32} (Int32 .(I), Int32 .(J), Float32 .(V), (m, n))
149
+ return CuSparseMatrixCSR (spcoo)
150
+ end
151
+
152
+ # function _sparse(I::CuVector, J::CuVector, V::CuVector, m, n; fmt=:csr)
153
+ # # Tv = Int32
154
+ # spcoo = CuSparseMatrixCOO{Float32, Int32}(Int32.(I), Int32.(J), Float32.(V), (m, n))
155
+ # if fmt == :csc
156
+ # return CuSparseMatrixCSC(spcoo)
157
+ # elseif fmt == :csr
158
+ # return CuSparseMatrixCSR(spcoo)
159
+ # elseif fmt == :coo
160
+ # return spcoo
161
+ # else
162
+ # error("Format :$fmt not available, use :csc, :csr, or :coo.")
163
+ # end
164
+ # end
165
+
166
+
167
+ # Workaround for https://github.com/JuliaGPU/CUDA.jl/issues/1113#issuecomment-955759875
168
+ function Base.:* (A:: CuMatrix , B:: CuSparseMatrixCSR )
169
+ @assert size (A, 2 ) == size (B, 1 )
170
+ return CuMatrix ((B' * A' )' )
171
+ end
172
+
145
173
146
174
@non_differentiable to_coo (x... )
147
175
@non_differentiable to_dense (x... )
0 commit comments