Skip to content

Commit 8700a15

Browse files
committed
file extensions in Save() and Load() methods are optional
if no file extension is given, the standard extension for the corresponding object type is added to the filename this is backwards compatible with the previous behavior but allows to use custom file extensions and call Save() and Load() with the exact same arguments
1 parent c5dd9c4 commit 8700a15

File tree

11 files changed

+152
-132
lines changed

11 files changed

+152
-132
lines changed

src/Bond.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11

22
#include "Bond.hpp"
3+
34
#include <algorithm>
5+
#include <filesystem>
6+
47
#include "utils/utils.hpp"
8+
59
using namespace std;
610

711
namespace cytnx {
@@ -469,46 +473,42 @@ namespace cytnx {
469473

470474
void Bond::Save(const std::string &fname) const {
471475
fstream f;
472-
f.open((fname + ".cybd"), ios::out | ios::trunc | ios::binary);
476+
if (std::filesystem::path(fname).has_extension()) {
477+
// filename extension is given
478+
f.open(fname, ios::out | ios::trunc | ios::binary);
479+
} else {
480+
// add filename extension
481+
f.open((fname + ".cybd"), ios::out | ios::trunc | ios::binary);
482+
}
473483
if (!f.is_open()) {
474484
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
475485
}
476486
this->_Save(f);
477487
f.close();
478488
}
479489
void Bond::Save(const char *fname) const {
480-
fstream f;
481-
string ffname = string(fname) + ".cybd";
482-
f.open((ffname), ios::out | ios::trunc | ios::binary);
483-
if (!f.is_open()) {
484-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
485-
}
486-
this->_Save(f);
487-
f.close();
490+
this->Save(string(fname));
488491
}
489492

490493
Bond Bond::Load(const std::string &fname) {
491494
Bond out;
492495
fstream f;
493-
f.open(fname, ios::in | ios::binary);
496+
if (std::filesystem::path(fname).has_extension()) {
497+
// filename extension is given
498+
f.open(fname, ios::in | ios::binary);
499+
} else {
500+
// add filename extension
501+
f.open((fname + ".cybd"), ios::in | ios::binary);
502+
}
494503
if (!f.is_open()) {
495504
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
496505
}
497506
out._Load(f);
498507
f.close();
499508
return out;
500509
}
501-
502510
Bond Bond::Load(const char *fname) {
503-
Bond out;
504-
fstream f;
505-
f.open(fname, ios::in | ios::binary);
506-
if (!f.is_open()) {
507-
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
508-
}
509-
out._Load(f);
510-
f.close();
511-
return out;
511+
return Bond::Load(string(fname));
512512
}
513513

514514
void Bond::_Save(fstream &f) const {

src/Symmetry.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <iostream>
55
#include <string>
66
#include <vector>
7+
#include <filesystem>
78

89
using namespace std;
910

@@ -240,46 +241,42 @@ namespace cytnx {
240241

241242
void cytnx::Symmetry::Save(const std::string &fname) const {
242243
fstream f;
243-
f.open((fname + ".cysym"), ios::out | ios::trunc | ios::binary);
244+
if (std::filesystem::path(fname).has_extension()) {
245+
// filename extension is given
246+
f.open(fname, ios::out | ios::trunc | ios::binary);
247+
} else {
248+
// add filename extension
249+
f.open((fname + ".cysym"), ios::out | ios::trunc | ios::binary);
250+
}
244251
if (!f.is_open()) {
245252
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
246253
}
247254
this->_Save(f);
248255
f.close();
249256
}
250257
void cytnx::Symmetry::Save(const char *fname) const {
251-
fstream f;
252-
string ffname = string(fname) + ".cysym";
253-
f.open((ffname), ios::out | ios::trunc | ios::binary);
254-
if (!f.is_open()) {
255-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
256-
}
257-
this->_Save(f);
258-
f.close();
258+
this->Save(string(fname));
259259
}
260260

261261
cytnx::Symmetry cytnx::Symmetry::Load(const std::string &fname) {
262262
Symmetry out;
263263
fstream f;
264-
f.open(fname, ios::in | ios::binary);
264+
if (std::filesystem::path(fname).has_extension()) {
265+
// filename extension is given
266+
f.open(fname, ios::in | ios::binary);
267+
} else {
268+
// add filename extension
269+
f.open((fname + ".cysym"), ios::in | ios::binary);
270+
}
265271
if (!f.is_open()) {
266272
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
267273
}
268274
out._Load(f);
269275
f.close();
270276
return out;
271277
}
272-
273278
cytnx::Symmetry cytnx::Symmetry::Load(const char *fname) {
274-
Symmetry out;
275-
fstream f;
276-
f.open(fname, ios::in | ios::binary);
277-
if (!f.is_open()) {
278-
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
279-
}
280-
out._Load(f);
281-
f.close();
282-
return out;
279+
return cytnx::Symmetry::Load(string(fname));
283280
}
284281

285282
//==================

src/Tensor.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
#include <typeinfo>
21
#include "Tensor.hpp"
2+
3+
#include <typeinfo>
4+
#include <filesystem>
5+
36
#include "linalg.hpp"
47
#include "utils/is.hpp"
58
#include "Type.hpp"
9+
610
using namespace std;
711

812
#ifdef BACKEND_TORCH
@@ -446,22 +450,21 @@ namespace cytnx {
446450
}
447451
void Tensor::Save(const std::string &fname) const {
448452
fstream f;
449-
f.open((fname + ".cytn"), ios::out | ios::trunc | ios::binary);
453+
if (std::filesystem::path(fname).has_extension()) {
454+
// filename extension is given
455+
f.open(fname, ios::out | ios::trunc | ios::binary);
456+
} else {
457+
// add filename extension
458+
f.open((fname + ".cytn"), ios::out | ios::trunc | ios::binary);
459+
}
450460
if (!f.is_open()) {
451461
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
452462
}
453463
this->_Save(f);
454464
f.close();
455465
}
456466
void Tensor::Save(const char *fname) const {
457-
fstream f;
458-
string ffname = string(fname) + ".cytn";
459-
f.open(ffname, ios::out | ios::trunc | ios::binary);
460-
if (!f.is_open()) {
461-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
462-
}
463-
this->_Save(f);
464-
f.close();
467+
this->Save(string(fname));
465468
}
466469
void Tensor::_Save(fstream &f) const {
467470
// header
@@ -493,7 +496,13 @@ namespace cytnx {
493496
Tensor Tensor::Load(const std::string &fname) {
494497
Tensor out;
495498
fstream f;
496-
f.open(fname, ios::in | ios::binary);
499+
if (std::filesystem::path(fname).has_extension()) {
500+
// filename extension is given
501+
f.open(fname, ios::in | ios::binary);
502+
} else {
503+
// add filename extension
504+
f.open((fname + ".cytn"), ios::in | ios::binary);
505+
}
497506
if (!f.is_open()) {
498507
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
499508
}
@@ -502,15 +511,7 @@ namespace cytnx {
502511
return out;
503512
}
504513
Tensor Tensor::Load(const char *fname) {
505-
Tensor out;
506-
fstream f;
507-
f.open(fname, ios::in | ios::binary);
508-
if (!f.is_open()) {
509-
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
510-
}
511-
out._Load(f);
512-
f.close();
513-
return out;
514+
return Tensor::Load(string(fname));
514515
}
515516
void Tensor::_Load(fstream &f) {
516517
// header

src/UniTensor.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
#include <typeinfo>
21
#include "UniTensor.hpp"
3-
#include "utils/utils.hpp"
42

3+
#include <typeinfo>
4+
#include <filesystem>
5+
6+
#include "utils/utils.hpp"
57
#include "linalg.hpp"
68
#include "random.hpp"
79

@@ -156,28 +158,33 @@ namespace cytnx {
156158

157159
void UniTensor::Save(const std::string &fname) const {
158160
fstream f;
159-
f.open((fname + ".cytnx"), ios::out | ios::trunc | ios::binary);
161+
if (std::filesystem::path(fname).has_extension()) {
162+
// filename extension is given
163+
f.open(fname, ios::out | ios::trunc | ios::binary);
164+
} else {
165+
// add filename extension
166+
f.open((fname + ".cytnx"), ios::out | ios::trunc | ios::binary);
167+
}
160168
if (!f.is_open()) {
161169
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
162170
}
163171
this->_Save(f);
164172
f.close();
165173
}
166174
void UniTensor::Save(const char *fname) const {
167-
fstream f;
168-
string ffname = string(fname) + ".cytnx";
169-
f.open((ffname), ios::out | ios::trunc | ios::binary);
170-
if (!f.is_open()) {
171-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
172-
}
173-
this->_Save(f);
174-
f.close();
175+
Save(string(fname));
175176
}
176177

177178
UniTensor UniTensor::Load(const std::string &fname) {
178179
UniTensor out;
179180
fstream f;
180-
f.open(fname, ios::in | ios::binary);
181+
if (std::filesystem::path(fname).has_extension()) {
182+
// filename extension is given
183+
f.open(fname, ios::in | ios::binary);
184+
} else {
185+
// add filename extension
186+
f.open((fname + ".cytnx"), ios::in | ios::binary);
187+
}
181188
if (!f.is_open()) {
182189
cytnx_error_msg(true, "[ERROR] invalid file path for load. >> %s\n", fname.c_str());
183190
}
@@ -186,16 +193,9 @@ namespace cytnx {
186193
return out;
187194
}
188195
UniTensor UniTensor::Load(const char *fname) {
189-
UniTensor out;
190-
fstream f;
191-
f.open(fname, ios::in | ios::binary);
192-
if (!f.is_open()) {
193-
cytnx_error_msg(true, "[ERROR] invalid file path for load. >> %s\n", fname);
194-
}
195-
out._Load(f);
196-
f.close();
197-
return out;
196+
return UniTensor::Load(string(fname));
198197
}
198+
199199
// Random Generators:
200200
UniTensor UniTensor::normal(const cytnx_uint64 &Nelem, const double &mean, const double &std,
201201
const std::vector<std::string> &in_labels, const unsigned int &seed,

src/backend/Storage.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "backend/Storage.hpp"
22

33
#include <iostream>
4+
#include <filesystem>
45

56
using namespace std;
67

@@ -83,22 +84,21 @@ namespace cytnx {
8384

8485
void Storage::Save(const std::string &fname) const {
8586
fstream f;
86-
f.open((fname + ".cyst"), ios::out | ios::trunc | ios::binary);
87+
if (std::filesystem::path(fname).has_extension()) {
88+
// filename extension is given
89+
f.open(fname, ios::out | ios::trunc | ios::binary);
90+
} else {
91+
// add filename extension
92+
f.open((fname + ".cyst"), ios::out | ios::trunc | ios::binary);
93+
}
8794
if (!f.is_open()) {
8895
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
8996
}
9097
this->_Save(f);
9198
f.close();
9299
}
93100
void Storage::Save(const char *fname) const {
94-
fstream f;
95-
string ffname = string(fname) + ".cyst";
96-
f.open(ffname, ios::out | ios::trunc | ios::binary);
97-
if (!f.is_open()) {
98-
cytnx_error_msg(true, "[ERROR] invalid file path for save.%s", "\n");
99-
}
100-
this->_Save(f);
101-
f.close();
101+
this->Save(string(fname));
102102
}
103103
void Storage::Tofile(const std::string &fname) const {
104104
fstream f;
@@ -230,7 +230,13 @@ namespace cytnx {
230230
Storage Storage::Load(const std::string &fname) {
231231
Storage out;
232232
fstream f;
233-
f.open(fname, ios::in | ios::binary);
233+
if (std::filesystem::path(fname).has_extension()) {
234+
// filename extension is given
235+
f.open(fname, ios::in | ios::binary);
236+
} else {
237+
// add filename extension
238+
f.open((fname + ".cyst"), ios::in | ios::binary);
239+
}
234240
if (!f.is_open()) {
235241
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
236242
}
@@ -239,15 +245,7 @@ namespace cytnx {
239245
return out;
240246
}
241247
Storage Storage::Load(const char *fname) {
242-
Storage out;
243-
fstream f;
244-
f.open(fname, ios::in | ios::binary);
245-
if (!f.is_open()) {
246-
cytnx_error_msg(true, "[ERROR] invalid file path for load.%s", "\n");
247-
}
248-
out._Load(f);
249-
f.close();
250-
return out;
248+
return Storage::Load(string(fname));
251249
}
252250
void Storage::_Load(fstream &f) {
253251
// header

0 commit comments

Comments
 (0)