Skip to content

Commit 91dd9ac

Browse files
committed
file extensions in Save() and Load() methods are optional
if no file extension is given, the standard extension for the corresponding object type is added to the filename this is backwards compatible with the previous behavior but allows to use custom file extensions and call Save() and Load() with the exact same arguments
1 parent c5dd9c4 commit 91dd9ac

File tree

11 files changed

+143
-143
lines changed

11 files changed

+143
-143
lines changed

src/Bond.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11

22
#include "Bond.hpp"
3+
34
#include <algorithm>
5+
#include <filesystem>
6+
47
#include "utils/utils.hpp"
8+
59
using namespace std;
610

711
namespace cytnx {
@@ -469,47 +473,39 @@ namespace cytnx {
469473

470474
void Bond::Save(const std::string &fname) const {
471475
fstream f;
472-
f.open((fname + ".cybd"), ios::out | ios::trunc | ios::binary);
473-
if (!f.is_open()) {
474-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
476+
if (std::filesystem::path(fname).has_extension()) {
477+
// filename extension is given
478+
f.open(fname, ios::out | ios::trunc | ios::binary);
479+
} else {
480+
// add filename extension
481+
f.open((fname + ".cybd"), ios::out | ios::trunc | ios::binary);
475482
}
476-
this->_Save(f);
477-
f.close();
478-
}
479-
void Bond::Save(const char *fname) const {
480-
fstream f;
481-
string ffname = string(fname) + ".cybd";
482-
f.open((ffname), ios::out | ios::trunc | ios::binary);
483483
if (!f.is_open()) {
484484
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
485485
}
486486
this->_Save(f);
487487
f.close();
488488
}
489+
void Bond::Save(const char *fname) const { this->Save(string(fname)); }
489490

490491
Bond Bond::Load(const std::string &fname) {
491492
Bond out;
492493
fstream f;
493-
f.open(fname, ios::in | ios::binary);
494-
if (!f.is_open()) {
495-
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
494+
if (std::filesystem::path(fname).has_extension()) {
495+
// filename extension is given
496+
f.open(fname, ios::in | ios::binary);
497+
} else {
498+
// add filename extension
499+
f.open((fname + ".cybd"), ios::in | ios::binary);
496500
}
497-
out._Load(f);
498-
f.close();
499-
return out;
500-
}
501-
502-
Bond Bond::Load(const char *fname) {
503-
Bond out;
504-
fstream f;
505-
f.open(fname, ios::in | ios::binary);
506501
if (!f.is_open()) {
507502
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
508503
}
509504
out._Load(f);
510505
f.close();
511506
return out;
512507
}
508+
Bond Bond::Load(const char *fname) { return Bond::Load(string(fname)); }
513509

514510
void Bond::_Save(fstream &f) const {
515511
cytnx_error_msg(!f.is_open(), "[ERROR][Bond] invalid fstream%s", "\n");

src/Symmetry.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <iostream>
55
#include <string>
66
#include <vector>
7+
#include <filesystem>
78

89
using namespace std;
910

@@ -240,46 +241,40 @@ namespace cytnx {
240241

241242
void cytnx::Symmetry::Save(const std::string &fname) const {
242243
fstream f;
243-
f.open((fname + ".cysym"), ios::out | ios::trunc | ios::binary);
244-
if (!f.is_open()) {
245-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
244+
if (std::filesystem::path(fname).has_extension()) {
245+
// filename extension is given
246+
f.open(fname, ios::out | ios::trunc | ios::binary);
247+
} else {
248+
// add filename extension
249+
f.open((fname + ".cysym"), ios::out | ios::trunc | ios::binary);
246250
}
247-
this->_Save(f);
248-
f.close();
249-
}
250-
void cytnx::Symmetry::Save(const char *fname) const {
251-
fstream f;
252-
string ffname = string(fname) + ".cysym";
253-
f.open((ffname), ios::out | ios::trunc | ios::binary);
254251
if (!f.is_open()) {
255252
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
256253
}
257254
this->_Save(f);
258255
f.close();
259256
}
257+
void cytnx::Symmetry::Save(const char *fname) const { this->Save(string(fname)); }
260258

261259
cytnx::Symmetry cytnx::Symmetry::Load(const std::string &fname) {
262260
Symmetry out;
263261
fstream f;
264-
f.open(fname, ios::in | ios::binary);
262+
if (std::filesystem::path(fname).has_extension()) {
263+
// filename extension is given
264+
f.open(fname, ios::in | ios::binary);
265+
} else {
266+
// add filename extension
267+
f.open((fname + ".cysym"), ios::in | ios::binary);
268+
}
265269
if (!f.is_open()) {
266270
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
267271
}
268272
out._Load(f);
269273
f.close();
270274
return out;
271275
}
272-
273276
cytnx::Symmetry cytnx::Symmetry::Load(const char *fname) {
274-
Symmetry out;
275-
fstream f;
276-
f.open(fname, ios::in | ios::binary);
277-
if (!f.is_open()) {
278-
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
279-
}
280-
out._Load(f);
281-
f.close();
282-
return out;
277+
return cytnx::Symmetry::Load(string(fname));
283278
}
284279

285280
//==================

src/Tensor.cpp

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
#include <typeinfo>
21
#include "Tensor.hpp"
2+
3+
#include <typeinfo>
4+
#include <filesystem>
5+
36
#include "linalg.hpp"
47
#include "utils/is.hpp"
58
#include "Type.hpp"
9+
610
using namespace std;
711

812
#ifdef BACKEND_TORCH
@@ -446,23 +450,20 @@ namespace cytnx {
446450
}
447451
void Tensor::Save(const std::string &fname) const {
448452
fstream f;
449-
f.open((fname + ".cytn"), ios::out | ios::trunc | ios::binary);
450-
if (!f.is_open()) {
451-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
453+
if (std::filesystem::path(fname).has_extension()) {
454+
// filename extension is given
455+
f.open(fname, ios::out | ios::trunc | ios::binary);
456+
} else {
457+
// add filename extension
458+
f.open((fname + ".cytn"), ios::out | ios::trunc | ios::binary);
452459
}
453-
this->_Save(f);
454-
f.close();
455-
}
456-
void Tensor::Save(const char *fname) const {
457-
fstream f;
458-
string ffname = string(fname) + ".cytn";
459-
f.open(ffname, ios::out | ios::trunc | ios::binary);
460460
if (!f.is_open()) {
461461
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
462462
}
463463
this->_Save(f);
464464
f.close();
465465
}
466+
void Tensor::Save(const char *fname) const { this->Save(string(fname)); }
466467
void Tensor::_Save(fstream &f) const {
467468
// header
468469
// check:
@@ -493,25 +494,21 @@ namespace cytnx {
493494
Tensor Tensor::Load(const std::string &fname) {
494495
Tensor out;
495496
fstream f;
496-
f.open(fname, ios::in | ios::binary);
497-
if (!f.is_open()) {
498-
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
497+
if (std::filesystem::path(fname).has_extension()) {
498+
// filename extension is given
499+
f.open(fname, ios::in | ios::binary);
500+
} else {
501+
// add filename extension
502+
f.open((fname + ".cytn"), ios::in | ios::binary);
499503
}
500-
out._Load(f);
501-
f.close();
502-
return out;
503-
}
504-
Tensor Tensor::Load(const char *fname) {
505-
Tensor out;
506-
fstream f;
507-
f.open(fname, ios::in | ios::binary);
508504
if (!f.is_open()) {
509505
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
510506
}
511507
out._Load(f);
512508
f.close();
513509
return out;
514510
}
511+
Tensor Tensor::Load(const char *fname) { return Tensor::Load(string(fname)); }
515512
void Tensor::_Load(fstream &f) {
516513
// header
517514
// check:

src/UniTensor.cpp

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
#include <typeinfo>
21
#include "UniTensor.hpp"
3-
#include "utils/utils.hpp"
42

3+
#include <typeinfo>
4+
#include <filesystem>
5+
6+
#include "utils/utils.hpp"
57
#include "linalg.hpp"
68
#include "random.hpp"
79

@@ -156,46 +158,40 @@ namespace cytnx {
156158

157159
void UniTensor::Save(const std::string &fname) const {
158160
fstream f;
159-
f.open((fname + ".cytnx"), ios::out | ios::trunc | ios::binary);
160-
if (!f.is_open()) {
161-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
161+
if (std::filesystem::path(fname).has_extension()) {
162+
// filename extension is given
163+
f.open(fname, ios::out | ios::trunc | ios::binary);
164+
} else {
165+
// add filename extension
166+
f.open((fname + ".cytnx"), ios::out | ios::trunc | ios::binary);
162167
}
163-
this->_Save(f);
164-
f.close();
165-
}
166-
void UniTensor::Save(const char *fname) const {
167-
fstream f;
168-
string ffname = string(fname) + ".cytnx";
169-
f.open((ffname), ios::out | ios::trunc | ios::binary);
170168
if (!f.is_open()) {
171169
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
172170
}
173171
this->_Save(f);
174172
f.close();
175173
}
174+
void UniTensor::Save(const char *fname) const { Save(string(fname)); }
176175

177176
UniTensor UniTensor::Load(const std::string &fname) {
178177
UniTensor out;
179178
fstream f;
180-
f.open(fname, ios::in | ios::binary);
181-
if (!f.is_open()) {
182-
cytnx_error_msg(true, "[ERROR] invalid file path for load. >> %s\n", fname.c_str());
179+
if (std::filesystem::path(fname).has_extension()) {
180+
// filename extension is given
181+
f.open(fname, ios::in | ios::binary);
182+
} else {
183+
// add filename extension
184+
f.open((fname + ".cytnx"), ios::in | ios::binary);
183185
}
184-
out._Load(f);
185-
f.close();
186-
return out;
187-
}
188-
UniTensor UniTensor::Load(const char *fname) {
189-
UniTensor out;
190-
fstream f;
191-
f.open(fname, ios::in | ios::binary);
192186
if (!f.is_open()) {
193-
cytnx_error_msg(true, "[ERROR] invalid file path for load. >> %s\n", fname);
187+
cytnx_error_msg(true, "[ERROR] invalid file path for load. >> %s\n", fname.c_str());
194188
}
195189
out._Load(f);
196190
f.close();
197191
return out;
198192
}
193+
UniTensor UniTensor::Load(const char *fname) { return UniTensor::Load(string(fname)); }
194+
199195
// Random Generators:
200196
UniTensor UniTensor::normal(const cytnx_uint64 &Nelem, const double &mean, const double &std,
201197
const std::vector<std::string> &in_labels, const unsigned int &seed,

src/backend/Storage.cpp

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "backend/Storage.hpp"
22

33
#include <iostream>
4+
#include <filesystem>
45

56
using namespace std;
67

@@ -83,23 +84,20 @@ namespace cytnx {
8384

8485
void Storage::Save(const std::string &fname) const {
8586
fstream f;
86-
f.open((fname + ".cyst"), ios::out | ios::trunc | ios::binary);
87-
if (!f.is_open()) {
88-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
87+
if (std::filesystem::path(fname).has_extension()) {
88+
// filename extension is given
89+
f.open(fname, ios::out | ios::trunc | ios::binary);
90+
} else {
91+
// add filename extension
92+
f.open((fname + ".cyst"), ios::out | ios::trunc | ios::binary);
8993
}
90-
this->_Save(f);
91-
f.close();
92-
}
93-
void Storage::Save(const char *fname) const {
94-
fstream f;
95-
string ffname = string(fname) + ".cyst";
96-
f.open(ffname, ios::out | ios::trunc | ios::binary);
9794
if (!f.is_open()) {
9895
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
9996
}
10097
this->_Save(f);
10198
f.close();
10299
}
100+
void Storage::Save(const char *fname) const { this->Save(string(fname)); }
103101
void Storage::Tofile(const std::string &fname) const {
104102
fstream f;
105103
f.open(fname, ios::out | ios::trunc | ios::binary);
@@ -230,25 +228,21 @@ namespace cytnx {
230228
Storage Storage::Load(const std::string &fname) {
231229
Storage out;
232230
fstream f;
233-
f.open(fname, ios::in | ios::binary);
234-
if (!f.is_open()) {
235-
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
231+
if (std::filesystem::path(fname).has_extension()) {
232+
// filename extension is given
233+
f.open(fname, ios::in | ios::binary);
234+
} else {
235+
// add filename extension
236+
f.open((fname + ".cyst"), ios::in | ios::binary);
236237
}
237-
out._Load(f);
238-
f.close();
239-
return out;
240-
}
241-
Storage Storage::Load(const char *fname) {
242-
Storage out;
243-
fstream f;
244-
f.open(fname, ios::in | ios::binary);
245238
if (!f.is_open()) {
246239
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
247240
}
248241
out._Load(f);
249242
f.close();
250243
return out;
251244
}
245+
Storage Storage::Load(const char *fname) { return Storage::Load(string(fname)); }
252246
void Storage::_Load(fstream &f) {
253247
// header
254248
unsigned long long sz;

0 commit comments

Comments
 (0)