44#include < spblas/concepts.hpp>
55#include < spblas/detail/log.hpp>
66
7+ #include < spblas/algorithms/transposed.hpp>
78#include < spblas/backend/csr_builder.hpp>
89#include < spblas/backend/spa_accumulator.hpp>
910#include < spblas/detail/operation_info_t.hpp>
@@ -95,14 +96,27 @@ void multiply(A&& a, B&& b, C&& c) {
9596 try {
9697 c_builder.insert_row (i, c_row.get ());
9798 } catch (...) {
98- throw std::runtime_error (" multiply: ran out of memory. CSR output view "
99- " has insufficient memory." );
99+ throw std::runtime_error (" multiply: SpGEMM ran out of memory." );
100100 }
101101 }
102102 c.update (c.values (), c.rowptr (), c.colind (), c.shape (),
103103 c.rowptr ()[c.shape ()[0 ]]);
104104}
105105
106+ template <matrix A, matrix B, matrix C>
107+ requires (__backend::column_iterable<A> && __backend::column_iterable<B> &&
108+ __detail::is_csc_view_v<C>)
109+ void multiply (A&& a, B&& b, C&& c) {
110+ log_trace (" " );
111+ if (__backend::shape (a)[0 ] != __backend::shape (c)[0 ] ||
112+ __backend::shape (b)[1 ] != __backend::shape (c)[1 ] ||
113+ __backend::shape (a)[1 ] != __backend::shape (b)[0 ]) {
114+ throw std::invalid_argument (
115+ " multiply: matrix dimensions are incompatible." );
116+ }
117+ multiply (transposed (b), transposed (a), transposed (c));
118+ }
119+
106120template <matrix A, matrix B, matrix C>
107121operation_info_t multiply_inspect (A&& a, B&& b, C&& c) {
108122 return operation_info_t {};
@@ -147,6 +161,26 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
147161 return operation_info_t {__backend::shape (c), nnz};
148162}
149163
164+ // C = AB
165+ // SpGEMM (Gustavson's Algorithm, transposed)
166+ template <matrix A, matrix B, matrix C>
167+ requires (__backend::column_iterable<A> && __backend::column_iterable<B> &&
168+ __detail::is_csc_view_v<C>)
169+ operation_info_t multiply_compute (A&& a, B&& b, C&& c) {
170+ log_trace (" " );
171+ if (__backend::shape (a)[0 ] != __backend::shape (c)[0 ] ||
172+ __backend::shape (b)[1 ] != __backend::shape (c)[1 ] ||
173+ __backend::shape (a)[1 ] != __backend::shape (b)[0 ]) {
174+ throw std::invalid_argument (
175+ " multiply: matrix dimensions are incompatible." );
176+ }
177+
178+ auto info = multiply_compute (transposed (b), transposed (a), transposed (c));
179+ info.update_impl_ ({info.result_shape ()[1 ], info.result_shape ()[0 ]},
180+ info.result_nnz ());
181+ return info;
182+ }
183+
150184template <matrix A, matrix B, matrix C>
151185 requires (__backend::row_iterable<A> && __backend::row_iterable<B> &&
152186 __detail::is_csr_view_v<C>)
@@ -156,6 +190,15 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
156190 info.update_impl_ (new_info.result_shape (), new_info.result_nnz ());
157191}
158192
193+ template <matrix A, matrix B, matrix C>
194+ requires (__backend::column_iterable<A> && __backend::column_iterable<B> &&
195+ __detail::is_csc_view_v<C>)
196+ void multiply_compute (operation_info_t & info, A&& a, B&& b, C&& c) {
197+ auto new_info = multiply_compute (std::forward<A>(a), std::forward<B>(b),
198+ std::forward<C>(c));
199+ info.update_impl_ (new_info.result_shape (), new_info.result_nnz ());
200+ }
201+
159202// C = AB
160203template <matrix A, matrix B, matrix C>
161204void multiply_fill (operation_info_t info, A&& a, B&& b, C&& c) {
0 commit comments