99
1010#pragma once
1111
12- #include " aoclsparse.h"
12+ #include < aoclsparse.h>
1313#include < cstdint>
1414
1515#include " aocl_wrappers.hpp"
16+ #include " detail/detail.hpp"
1617#include < fmt/core.h>
18+ #include < spblas/algorithms/transposed.hpp>
1719#include < spblas/detail/log.hpp>
1820#include < spblas/detail/operation_info_t.hpp>
1921#include < spblas/detail/ranges.hpp>
3032namespace spblas {
3133
3234template <matrix A, matrix B, matrix C>
33- requires (__detail::has_csr_base<A>) &&
34- (__detail::has_csr_base<B>) && __detail::is_csr_view_v<C>
35+ requires (__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
36+ (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
37+ __detail::is_csr_view_v<C>
3538operation_info_t multiply_compute (A&& a, B&& b, C&& c) {
3639 log_trace (" " );
3740 auto a_base = __detail::get_ultimate_base (a);
@@ -41,32 +44,27 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
4144 using I = tensor_index_t <C>;
4245 using O = tensor_offset_t <C>;
4346
44- aoclsparse_matrix csrA = nullptr ;
45- aoclsparse_mat_descr descrA;
47+ aoclsparse_matrix csrA = __aoclsparse::create_matrix_handle (a_base);
48+ aoclsparse_matrix csrB = __aoclsparse::create_matrix_handle (b_base);
49+
50+ aoclsparse_operation opA = __aoclsparse::get_transpose (a);
51+ aoclsparse_operation opB = __aoclsparse::get_transpose (b);
52+
53+ aoclsparse_mat_descr descrA = NULL ;
4654 aoclsparse_status status = aoclsparse_create_mat_descr (&descrA);
4755 if (status != aoclsparse_status_success) {
4856 fmt::print (" \t descr creation failed\n " );
4957 }
5058 aoclsparse_set_mat_type (descrA, aoclsparse_matrix_type_general);
51- aoclsparse_index_base indexingA = aoclsparse_index_base_zero;
52- aoclsparse_operation opA = aoclsparse_operation_none;
5359
54- const index_t a_nrows = __backend::shape (a_base)[0 ];
55- const index_t a_ncols = __backend::shape (a_base)[1 ];
56-
57- aoclsparse_matrix csrB = nullptr ;
58- aoclsparse_mat_descr descrB;
60+ aoclsparse_mat_descr descrB = NULL ;
5961 status = aoclsparse_create_mat_descr (&descrB);
6062 if (status != aoclsparse_status_success) {
6163 fmt::print (" \t descr creation failed\n " );
6264 }
6365
6466 aoclsparse_set_mat_type (descrB, aoclsparse_matrix_type_general);
6567 aoclsparse_index_base indexingB = aoclsparse_index_base_zero;
66- aoclsparse_operation opB = aoclsparse_operation_none;
67-
68- const index_t b_nrows = __backend::shape (b_base)[0 ];
69- const index_t b_ncols = __backend::shape (b_base)[1 ];
7068
7169 aoclsparse_matrix csrC = nullptr ;
7270 aoclsparse_index_base indexingC = aoclsparse_index_base_zero;
@@ -76,22 +74,6 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
7674 index_t * c_colind = nullptr ;
7775 T* c_values = nullptr ;
7876
79- const aoclsparse_int nnzA = a_base.rowptr ().data ()[a_nrows] - indexingA;
80- const aoclsparse_int nnzB = b_base.rowptr ().data ()[b_nrows] - indexingB;
81-
82- status = __aoclsparse::aoclsparse_create_csr (
83- &csrA, indexingA, a_nrows, a_ncols, nnzA, a_base.rowptr ().data (),
84- a_base.colind ().data (), a_base.values ().data ());
85- if (status != aoclsparse_status_success) {
86- fmt::print (" \t csr matrix A creation failed\n " );
87- }
88- status = __aoclsparse::aoclsparse_create_csr (
89- &csrB, indexingB, b_nrows, b_ncols, nnzB, b_base.rowptr ().data (),
90- b_base.colind ().data (), b_base.values ().data ());
91- if (status != aoclsparse_status_success) {
92- fmt::print (" \t csr matrix B creation failed\n " );
93- }
94-
9577 aoclsparse_request request = aoclsparse_stage_nnz_count;
9678 status =
9779 aoclsparse_sp2m (opA, descrA, csrA, opB, descrB, csrB, request, &csrC);
@@ -108,20 +90,22 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
10890 // Check: csrA and csrB destroyed when the operation_info destructor is
10991 // called?
11092
93+ aoclsparse_destroy_mat_descr (descrA);
94+ aoclsparse_destroy_mat_descr (descrB);
11195 return operation_info_t {
11296 index<>{__backend::shape (c)[0 ], __backend::shape (c)[1 ]}, c_nnz,
11397 __aoclsparse::operation_state_t {csrA, csrB, csrC}};
11498}
11599
116100template <matrix A, matrix B, matrix C>
117- requires (__detail::has_csr_base<A>) &&
118- (__detail::has_csr_base<B>) && __detail::is_csr_view_v<C>
101+ requires (__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
102+ (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
103+ __detail::is_csr_view_v<C>
119104void multiply_fill (operation_info_t & info, A&& a, B&& b, C&& c) {
120105 log_trace (" " );
121106
122107 auto a_base = __detail::get_ultimate_base (a);
123108 auto b_base = __detail::get_ultimate_base (b);
124- auto c_base = __detail::get_ultimate_base (c);
125109
126110 using T = tensor_scalar_t <C>;
127111 using I = tensor_index_t <C>;
@@ -135,35 +119,30 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
135119 aoclsparse_matrix csrC = info.state_ .c_handle ;
136120 offset_t c_nnz = info.result_nnz ();
137121
138- aoclsparse_mat_descr descrA;
122+ aoclsparse_operation opA = __aoclsparse::get_transpose (a);
123+ aoclsparse_operation opB = __aoclsparse::get_transpose (b);
124+
125+ aoclsparse_mat_descr descrA = NULL ;
139126 aoclsparse_status status = aoclsparse_create_mat_descr (&descrA);
140127 if (status != aoclsparse_status_success) {
141128 fmt::print (" \t descr creation failed\n " );
142129 }
143130
144131 aoclsparse_set_mat_type (descrA, aoclsparse_matrix_type_general);
145132 aoclsparse_index_base indexingA = aoclsparse_index_base_zero;
146- aoclsparse_operation opA = aoclsparse_operation_none;
147-
148- const index_t a_nrows = __backend::shape (a_base)[0 ];
149- const index_t a_ncols = __backend::shape (a_base)[1 ];
150133
151- aoclsparse_mat_descr descrB;
134+ aoclsparse_mat_descr descrB = NULL ;
152135 status = aoclsparse_create_mat_descr (&descrB);
153136 if (status != aoclsparse_status_success) {
154137 fmt::print (" \t descr creation failed\n " );
155138 }
156139
157140 aoclsparse_set_mat_type (descrB, aoclsparse_matrix_type_general);
158141 aoclsparse_index_base indexingB = aoclsparse_index_base_zero;
159- aoclsparse_operation opB = aoclsparse_operation_none;
160-
161- const index_t b_nrows = __backend::shape (b_base)[0 ];
162- const index_t b_ncols = __backend::shape (b_base)[1 ];
163142
164143 aoclsparse_index_base indexingC = aoclsparse_index_base_zero;
165- index_t c_nrows = __backend::shape (c_base )[0 ];
166- index_t c_ncols = __backend::shape (c_base )[1 ];
144+ index_t c_nrows = __backend::shape (c )[0 ];
145+ index_t c_ncols = __backend::shape (c )[1 ];
167146 offset_t * c_rowptr = nullptr ;
168147 index_t * c_colind = nullptr ;
169148 T* c_values = nullptr ;
@@ -194,6 +173,24 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
194173 if (alpha_optional.has_value ()) {
195174 scale (alpha, c);
196175 }
176+ aoclsparse_destroy_mat_descr (descrA);
177+ aoclsparse_destroy_mat_descr (descrB);
178+ }
179+
180+ template <matrix A, matrix B, matrix C>
181+ requires (__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
182+ (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
183+ __detail::is_csc_view_v<C>
184+ operation_info_t multiply_compute (A&& a, B&& b, C&& c) {
185+ return multiply_compute (transposed (b), transposed (a), transposed (c));
186+ }
187+
188+ template <matrix A, matrix B, matrix C>
189+ requires ((__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
190+ (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
191+ __detail::is_csc_view_v<C>)
192+ void multiply_fill (operation_info_t & info, A&& a, B&& b, C&& c) {
193+ multiply_fill (info, transposed (b), transposed (a), transposed (c));
197194}
198195
199196} // namespace spblas
0 commit comments