@@ -129,6 +129,12 @@ template <typename ValueType> class CSRMatrix : public Matrix<ValueType> {
129129 rowOffsets = std::shared_ptr<size_t []>(src->rowOffsets , src->rowOffsets .get () + rowLowerIncl);
130130 }
131131
132+ CSRMatrix (size_t numRows, size_t numCols, size_t numNonZeros, std::shared_ptr<ValueType[]> &values,
133+ std::shared_ptr<size_t []> &colIdxs, std::shared_ptr<size_t []> &rowOffsets)
134+ : Matrix<ValueType>(numRows, numCols), numRowsAllocated(numRows), isRowAllocatedBefore(false ),
135+ maxNumNonZeros (numNonZeros), values(values), colIdxs(colIdxs), rowOffsets(rowOffsets), lastAppendedRowIdx(0 ) {
136+ }
137+
132138 virtual ~CSRMatrix () {
133139 // nothing to do
134140 }
@@ -163,11 +169,20 @@ template <typename ValueType> class CSRMatrix : public Matrix<ValueType> {
163169 }
164170
165171 void shrinkNumNonZeros (size_t numNonZeros) {
166- if (numNonZeros > getNumNonZeros ())
172+ size_t actualNumNonZeros = getNumNonZeros ();
173+ if (numNonZeros > actualNumNonZeros)
167174 throw std::runtime_error (" CSRMatrix (shrinkNumNonZeros): "
168- " numNonZeros can only be shrunk" );
169- // TODO Here we could reduce the allocated size of the values and
170- // colIdxs arrays.
175+ " cannot shrink below actual non-zero count" );
176+ // allocate new buffers
177+ auto newValues = std::shared_ptr<ValueType[]>(new ValueType[numNonZeros], std::default_delete<ValueType[]>());
178+ auto newColIdxs = std::shared_ptr<size_t []>(new size_t [numNonZeros], std::default_delete<size_t []>());
179+ // copy first numNonZeros entries
180+ std::memcpy (newValues.get (), values.get (), numNonZeros * sizeof (ValueType));
181+ std::memcpy (newColIdxs.get (), colIdxs.get (), numNonZeros * sizeof (size_t ));
182+
183+ values = std::move (newValues);
184+ colIdxs = std::move (newColIdxs);
185+ maxNumNonZeros = numNonZeros;
171186 }
172187
173188 ValueType *getValues () { return values.get (); }
0 commit comments