Skip to content

Commit a81967a

Browse files
committed
core: save and load json
1 parent 33df93e commit a81967a

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

topoloss/core.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import os
12
import torch.nn as nn
23
from einops import rearrange
34
from typing import Union
5+
import json
46
from .utils.getting_modules import get_layer_by_name
57
from .cortical_sheet.output import get_cortical_sheet_conv, get_cortical_sheet_linear
68
from .cortical_sheet.input import get_cortical_sheet_linear_input
@@ -108,3 +110,27 @@ def compute(self, model, reduce_mean=True, do_scaling=True):
108110
return sum(loss_values) / len(loss_values)
109111
else:
110112
return layer_wise_losses
113+
114+
def save_json(self, filename: str):
115+
116+
data = []
117+
for loss in self.losses:
118+
d = loss.__dict__
119+
d['name'] = loss.__class__.__name__
120+
data.append(d)
121+
122+
with open(filename, "w") as f:
123+
json.dump(data, f, indent=4)
124+
125+
@classmethod
126+
def from_json(cls, filename: str):
127+
assert os.path.exists(filename), f"File not found: {filename}"
128+
with open(filename, "r") as f:
129+
data = json.load(f)
130+
assert isinstance(data, list), f"Expected data to be a list but got: {type(data)}"
131+
losses = []
132+
for d in data:
133+
name = d.pop("name")
134+
losses.append(globals()[name](**d))
135+
return cls(losses=losses)
136+

0 commit comments

Comments
 (0)