Skip to content

Commit 85f981d

Browse files
committed
Add binary file saving methods to SparseMatrix
1 parent 03c07d6 commit 85f981d

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

src/dsm/dsm.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
static constexpr uint8_t DSM_VERSION_MAJOR = 2;
88
static constexpr uint8_t DSM_VERSION_MINOR = 3;
9-
static constexpr uint8_t DSM_VERSION_PATCH = 15;
9+
static constexpr uint8_t DSM_VERSION_PATCH = 16;
1010

1111
static auto const DSM_VERSION =
1212
std::format("{}.{}.{}", DSM_VERSION_MAJOR, DSM_VERSION_MINOR, DSM_VERSION_PATCH);

src/dsm/headers/SparseMatrix.hpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <vector>
1818
#include <cmath>
1919
#include <format>
20+
#include <fstream>
2021

2122
#include "../utility/Logger.hpp"
2223
#include "../utility/Typedef.hpp"
@@ -190,6 +191,13 @@ namespace dsm {
190191
/// @brief reshape the matrix
191192
void reshape(Id dim);
192193

194+
/// @brief save the matrix to a binary file (.dsmcache)
195+
/// @param filename the name of the file
196+
void cache(std::string const& filename) const;
197+
/// @brief load the matrix from a binary file (.dsmcache)
198+
/// @param filename the name of the file
199+
void load(std::string const& filename);
200+
193201
/// @brief return the begin iterator of the matrix
194202
/// @return the begin iterator
195203
typename std::unordered_map<Id, T>::const_iterator begin() const {
@@ -632,6 +640,43 @@ namespace dsm {
632640
}
633641
}
634642

643+
template <typename T>
644+
void SparseMatrix<T>::cache(std::string const& filename) const {
645+
std::ofstream file(filename, std::ios::binary);
646+
if (!file.is_open()) {
647+
Logger::error(std::format("Error opening file \"{}\" for writing.", filename));
648+
}
649+
file.write(reinterpret_cast<const char*>(&_rows), sizeof(Id));
650+
file.write(reinterpret_cast<const char*>(&_cols), sizeof(Id));
651+
size_t size = _matrix.size();
652+
file.write(reinterpret_cast<const char*>(&size), sizeof(size_t));
653+
for (auto& it : _matrix) {
654+
file.write(reinterpret_cast<const char*>(&it.first), sizeof(Id));
655+
file.write(reinterpret_cast<const char*>(&it.second), sizeof(T));
656+
}
657+
file.close();
658+
}
659+
660+
template <typename T>
661+
void SparseMatrix<T>::load(std::string const& filename) {
662+
std::ifstream file(filename, std::ios::binary);
663+
if (!file.is_open()) {
664+
Logger::error(std::format("Error opening file \"{}\" for reading.", filename));
665+
}
666+
file.read(reinterpret_cast<char*>(&_rows), sizeof(Id));
667+
file.read(reinterpret_cast<char*>(&_cols), sizeof(Id));
668+
size_t size;
669+
file.read(reinterpret_cast<char*>(&size), sizeof(size_t));
670+
for (size_t i = 0; i < size; ++i) {
671+
Id index;
672+
T value;
673+
file.read(reinterpret_cast<char*>(&index), sizeof(Id));
674+
file.read(reinterpret_cast<char*>(&value), sizeof(T));
675+
_matrix.insert_or_assign(index, value);
676+
}
677+
file.close();
678+
}
679+
635680
template <typename T>
636681
const T& SparseMatrix<T>::operator()(Id i, Id j) const {
637682
if (i >= _rows || j >= _cols) {

test/Test_sparsematrix.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,4 +746,25 @@ TEST_CASE("Boolean Matrix") {
746746
CHECK(m.getRowDim() == 5);
747747
CHECK(m.getColDim() == 5);
748748
}
749+
SUBCASE("Caching") {
750+
/*
751+
The caching function should cache the matrix
752+
GIVEN: the caching function is called
753+
WHEN: the function is called on a matrix
754+
THEN: the function should cache the matrix
755+
*/
756+
{
757+
SparseMatrix<bool> m(3, 3);
758+
m.insert(0, 0, true);
759+
m.insert(1, 2, true);
760+
m.cache("./data/test.dsmcache");
761+
}
762+
SparseMatrix<bool> m;
763+
m.load("./data/test.dsmcache");
764+
CHECK(m(0, 0));
765+
CHECK(m(1, 2));
766+
CHECK(m.size() == 2);
767+
CHECK(m.getRowDim() == 3);
768+
CHECK(m.getColDim() == 3);
769+
}
749770
}

0 commit comments

Comments
 (0)