Skip to content

Commit cd80b50

Browse files
LuloDuartepdamme
authored andcommitted
Shrink number of non zeroes for CSRMatrix
1 parent 3e46578 commit cd80b50

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

src/runtime/local/datastructures/CSRMatrix.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17+
#include <runtime/local/datastructures/AllocationDescriptorHost.h>
1718
#include <runtime/local/io/DaphneSerializer.h>
1819

1920
#include "CSRMatrix.h"

src/runtime/local/datastructures/CSRMatrix.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@ template <typename ValueType> class CSRMatrix : public Matrix<ValueType> {
129129
rowOffsets = std::shared_ptr<size_t[]>(src->rowOffsets, src->rowOffsets.get() + rowLowerIncl);
130130
}
131131

132+
CSRMatrix(size_t numRows, size_t numCols, size_t numNonZeros, std::shared_ptr<ValueType[]> &values,
133+
std::shared_ptr<size_t[]> &colIdxs, std::shared_ptr<size_t[]> &rowOffsets)
134+
: Matrix<ValueType>(numRows, numCols), numRowsAllocated(numRows), isRowAllocatedBefore(false),
135+
maxNumNonZeros(numNonZeros), values(values), colIdxs(colIdxs), rowOffsets(rowOffsets), lastAppendedRowIdx(0) {
136+
}
137+
132138
virtual ~CSRMatrix() {
133139
// nothing to do
134140
}
@@ -163,11 +169,20 @@ template <typename ValueType> class CSRMatrix : public Matrix<ValueType> {
163169
}
164170

165171
void shrinkNumNonZeros(size_t numNonZeros) {
166-
if (numNonZeros > getNumNonZeros())
172+
size_t actualNumNonZeros = getNumNonZeros();
173+
if (numNonZeros > actualNumNonZeros)
167174
throw std::runtime_error("CSRMatrix (shrinkNumNonZeros): "
168-
"numNonZeros can only be shrunk");
169-
// TODO Here we could reduce the allocated size of the values and
170-
// colIdxs arrays.
175+
"cannot shrink below actual non-zero count");
176+
// allocate new buffers
177+
auto newValues = std::shared_ptr<ValueType[]>(new ValueType[numNonZeros], std::default_delete<ValueType[]>());
178+
auto newColIdxs = std::shared_ptr<size_t[]>(new size_t[numNonZeros], std::default_delete<size_t[]>());
179+
// copy first numNonZeros entries
180+
std::memcpy(newValues.get(), values.get(), numNonZeros * sizeof(ValueType));
181+
std::memcpy(newColIdxs.get(), colIdxs.get(), numNonZeros * sizeof(size_t));
182+
183+
values = std::move(newValues);
184+
colIdxs = std::move(newColIdxs);
185+
maxNumNonZeros = numNonZeros;
171186
}
172187

173188
ValueType *getValues() { return values.get(); }

0 commit comments

Comments
 (0)