Skip to content

Commit 8b61cd0

Browse files
committed
added init of weights
1 parent eb2cb94 commit 8b61cd0

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

chebai/models/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ def _execute(
263263
d = dict(data=data, labels=labels, output=model_output, preds=pr)
264264
if log:
265265
if self.criterion is not None:
266+
f.init_weights()
266267
loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss(
267268
model_output, labels, data.get("loss_kwargs", dict())
268269
)

extras/weight_loader.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
import csv
22
import torch
3+
import os
34

45

56
#inint weights in a csv file
6-
def init_weights(path="/home/programmer/Bachelorarbeit/weights/first_it.csv",path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
7-
with open(path_to_split, 'r') as csvfile:
8-
with open(path, 'w') as to_file:
9-
fieldnames = ['idents','label','weights']
10-
writer = csv.writer(to_file)
11-
writer.writerow(fieldnames)
12-
reader = csv.reader(csvfile)
13-
weight = 1 / get_size(path_to_split)
14-
for row in reader:
15-
if row[1] == "train" or row[1] == "validation":
16-
#print(type(row[0]))
17-
writer.writerow([int(row[0]),row[1],weight])
18-
19-
def mock_init_weights(path="/home/programmer/Bachelorarbeit/weights/first_it.csv",path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
7+
def init_weights(path="../weights/first_it.csv",path_to_split="../split/splits.csv"):
8+
if not os.path.exists("../weights/first_it.csv"):
9+
with open(path_to_split, 'r') as csvfile:
10+
with open(path, 'w') as to_file:
11+
fieldnames = ['idents','label','weights']
12+
writer = csv.writer(to_file)
13+
writer.writerow(fieldnames)
14+
reader = csv.reader(csvfile)
15+
weight = 1 / get_size(path_to_split)
16+
for row in reader:
17+
if row[1] == "train" or row[1] == "validation":
18+
#print(type(row[0]))
19+
writer.writerow([int(row[0]),row[1],weight])
20+
21+
def mock_init_weights(path="../weights/first_it.csv",path_to_split="../split/splits.csv"):
2022
with open(path_to_split, 'r') as csvfile:
2123
with open(path, 'w') as to_file:
2224
fieldnames = ['idents','label','weights']
@@ -31,7 +33,7 @@ def mock_init_weights(path="/home/programmer/Bachelorarbeit/weights/first_it.csv
3133

3234
#check the size of a csv file given a filter for the second object
3335
# assumes csv file has a header
34-
def get_size(path="/home/programmer/Bachelorarbeit/split/splits.csv",filter=["train"]) -> int:
36+
def get_size(path="../split/splits.csv",filter=["train"]) -> int:
3537
with open(path,'r') as file:
3638
reader = csv.reader(file)
3739
size = -1
@@ -40,15 +42,13 @@ def get_size(path="/home/programmer/Bachelorarbeit/split/splits.csv",filter=["tr
4042
size = size + 1
4143
return size
4244
#get a dictory with the ids and weights of the data points
43-
def get_weights(idents:tuple[int,...],path="/home/programmer/Bachelorarbeit/weights/first_it.csv")-> dict[str,float]:
45+
def get_weights(idents:tuple[int,...],path="../weights/first_it.csv")-> dict[str,float]:
4446
value = dict()
4547
for i in idents:
4648
weight = find_weight(path,i)
4749
value.update({str(i):weight})
4850
return value
4951

50-
51-
5252
#finds the weight for a specific datapoint
5353
def find_weight(path:str,ident:int)-> float:
5454
with open(path,'r') as file:
@@ -60,7 +60,7 @@ def find_weight(path:str,ident:int)-> float:
6060
label = find_label(id=ident)
6161
print(f"{ident} is not in file with {label} ")
6262

63-
def find_label(id:int,path="/home/programmer/Bachelorarbeit/split/splits.csv")-> str:
63+
def find_label(id:int,path="../split/splits.csv")-> str:
6464
with open(path,'r') as file:
6565
reader = csv.reader(file)
6666
for row in reader:
@@ -81,8 +81,6 @@ def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tup
8181
index = index + 1
8282
return weight
8383

84-
85-
8684
def testing():
8785
print("hello world")
8886

@@ -99,6 +97,6 @@ def create_weight_tensor(weight:float)-> torch.tensor:
9997
def create_class_weights()-> torch.tensor:
10098
pass
10199

102-
mock_init_weights()
100+
#mock_init_weights()
103101
# print(get_weights((233713,51990)))
104102

0 commit comments

Comments
 (0)