Skip to content

Commit c745dfe

Browse files
committed
sync
1 parent 2d54d0a commit c745dfe

File tree

4 files changed

+52
-22
lines changed

4 files changed

+52
-22
lines changed

chebai/models/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,17 +261,16 @@ def _execute(
261261
model_output = self(data, **data.get("model_kwargs", dict()))
262262
pr, tar = self._get_prediction_and_labels(data, labels, model_output)
263263
d = dict(data=data, labels=labels, output=model_output, preds=pr)
264+
torch.save(d,"d.pt")
264265
if log:
265266
if self.criterion is not None:
266-
f.init_weights()
267267
loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss(
268268
model_output, labels, data.get("loss_kwargs", dict())
269269
)
270270
loss_kwargs = dict()
271271
if self.pass_loss_kwargs:
272272
loss_kwargs = loss_kwargs_candidates
273-
dict_weights = f.get_weights(data['idents'])
274-
loss_kwargs['weights'] = f.create_data_weights(len(data['idents']),data['labels'].size(dim=1),dict_weights,data['idents'])
273+
torch.save(loss_data,"loss_data.pt")
275274
loss_kwargs["current_epoch"] = self.trainer.current_epoch
276275
loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
277276
if isinstance(loss, tuple):

chebai/preprocessing/collate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
7777
model_kwargs: Dict = dict()
7878
# Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs.
7979
loss_kwargs: Dict = dict()
80-
8180
if isinstance(data[0], tuple):
8281
# For legacy data
8382
x, y, idents = zip(*data)
@@ -106,6 +105,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
106105
lens = torch.tensor(list(map(len, x)))
107106
model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None]
108107
model_kwargs["lens"] = lens
108+
for d in data:
109+
id = d["ident"]
110+
weight = d["weight"]
111+
loss_kwargs["ident"] = weight
112+
109113

110114
return XYData(
111115
pad_sequence([torch.tensor(a) for a in x], batch_first=True),

chebai/preprocessing/datasets/base.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from torch.utils.data import DataLoader
1919

2020
from chebai.preprocessing import reader as dr
21+
import sys
22+
sys.path.insert(1,'/home/programmer/Bachelorarbeit/python-chebai')
23+
24+
import extras.weight_loader as f
2125

2226

2327
class XYBaseDataModule(LightningDataModule):
@@ -1123,6 +1127,7 @@ def _retrieve_splits_from_csv(self) -> None:
11231127
os.path.join(self.processed_dir, filename)
11241128
)
11251129
df_data = pd.DataFrame(data)
1130+
11261131

11271132
train_ids = splits_df[splits_df["split"] == "train"]["id"]
11281133
validation_ids = splits_df[splits_df["split"] == "validation"]["id"]
@@ -1165,6 +1170,9 @@ def load_processed_data(
11651170
raise ValueError(
11661171
"Either kind or filename is required to load the correct dataset, both are None"
11671172
)
1173+
if kind == "train":
1174+
print("loading train data")
1175+
11681176

11691177
# If both kind and filename are given, use filename
11701178
if kind is not None and filename is None:
@@ -1174,10 +1182,19 @@ def load_processed_data(
11741182
]
11751183
else:
11761184
data_df = self.dynamic_split_dfs[kind]
1177-
return data_df.to_dict(orient="records")
1185+
data = data_df.to_dict(orient="records")
1186+
if kind == "train":
1187+
data = f.add_train_weights(data)
1188+
if kind == "validation":
1189+
print(kind)
1190+
data = f.add_val_weights(data)
1191+
torch.save(data,"gewicht.pt")
1192+
1193+
return data
11781194

11791195
# If filename is provided
1180-
return self.load_processed_data_from_file(filename)
1196+
data = self.load_processed_data_from_file(filename)
1197+
return data
11811198

11821199
def load_processed_data_from_file(self, filename):
11831200
return torch.load(os.path.join(filename), weights_only=False)

extras/weight_loader.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def init_weights(path="../weights/first_it.csv",path_to_split="../split/splits.c
1515
reader = csv.reader(csvfile)
1616
weight = 1 / get_size(path_to_split)
1717
for row in reader:
18-
if row[1] == "train" or row[1] == "validation":
18+
if row[1] == "train":
1919
#print(type(row[0]))
2020
writer.writerow([int(row[0]),row[1],weight])
2121

@@ -28,7 +28,7 @@ def mock_init_weights(path="../weights/first_it.csv",path_to_split="../split/spl
2828
reader = csv.reader(csvfile)
2929
weight = 1
3030
for row in reader:
31-
if row[1] == "train" or row[1] == "validation":
31+
if row[1] == "train":
3232
writer.writerow([int(row[0]),row[1],weight])
3333
weight = weight + 1
3434

@@ -71,16 +71,16 @@ def find_label(id:int,path="../split/splits.csv")-> str:
7171

7272
#to do
7373
# return should be a tuple of weights matching the sequence of the target and label tensor
74-
def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tuple[int,...])-> torch.tensor:
75-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
76-
weight = torch.empty(batchsize,dim,device=device)
77-
index = 0
78-
for i in idents:
79-
w = weights[str(i)]
80-
for j in range(0,dim):
81-
weight[index][j] = float(w)
82-
index = index + 1
83-
return weight
74+
# def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tuple[int,...])-> torch.tensor:
75+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
76+
# weight = torch.empty(batchsize,dim,device=device)
77+
# index = 0
78+
# for i in idents:
79+
# w = weights[str(i)]
80+
# for j in range(0,dim):
81+
# weight[index][j] = float(w)
82+
# index = index + 1
83+
# return weight
8484

8585
def testing():
8686
print("hello world")
@@ -89,15 +89,25 @@ def testing():
8989
def create_weight_tensor(weight:float)-> torch.tensor:
9090
pass
9191

92+
def add_val_weights(ids, weight=1):
    """Attach the same constant weight to every validation record.

    Each record is updated in place: its ``"weight"`` key is set to
    *weight*.

    Args:
        ids: Iterable of dict records to annotate.
        weight: Value stored under ``"weight"``. Defaults to ``1``, the
            constant that was previously hard-coded.

    Returns:
        The same ``ids`` object that was passed in, after annotation.
    """
    for record in ids:
        record["weight"] = weight
    return ids
9297

9398

99+
def add_train_weights(ids, path="/home/programmer/Bachelorarbeit/weights/first_it.csv"):
    """Attach a per-record weight to every training record.

    For each record, ``find_weight`` is called with *path* and the record's
    ``"ident"``; the result is stored under the record's ``"weight"`` key.
    Records are updated in place.

    Args:
        ids: Iterable of dict records, each providing an ``"ident"`` key.
        path: Weight file handed to ``find_weight``. Defaults to the
            location that was previously hard-coded in the body, so existing
            callers behave as before; pass another path to run on a
            different machine or weight file.

    Returns:
        The same ``ids`` object that was passed in, after annotation.
    """
    # NOTE(review): find_weight is invoked once per record; if it re-reads the
    # file on every call this is quadratic -- confirm and consider loading the
    # weights once.
    for record in ids:
        record["weight"] = find_weight(path, ident=record["ident"])
    return ids
94105

106+
def check_weights(data):
107+
for i in data:
108+
print(f"({i["ident"]} , {i["weight"]}")
95109

96110

97-
98-
def create_class_weights()-> torch.tensor:
99-
pass
100-
101111
#mock_init_weights()
102112
# print(get_weights((233713,51990)))
103113

0 commit comments

Comments
 (0)