
Commit eb2cb94

boost loss without init of weights
1 parent 4c76fec commit eb2cb94

File tree

4 files changed: +27, -16 lines changed


chebai/loss/boost_bce.py

Lines changed: 4 additions & 10 deletions
@@ -5,13 +5,13 @@
 import extras.weight_loader as f


-class BCE_Point_boosting(torch.nn.BCEWithLogitsLoss):
+class BCE_Boosting(torch.nn.BCEWithLogitsLoss):

     def __init__(
         self,
         **kwargs
     ):
-        super().__init__(reduction=None,**kwargs)
+        super().__init__(reduction='none',**kwargs)

     def forward(
         self,
@@ -22,11 +22,5 @@ def forward(
     )-> torch.Tensor:
         weights = kwargs['weights']
         loss = super().forward(input=input,target=target)
-        weights_tensor = f.create_weight_tensor(weights)
-        loss_scaled = torch.matmul(weights_tensor,loss)
-        return torch.mean(loss_scaled)
-
-
-
-
-
+        loss_scaled = loss * weights
+        return torch.mean(loss_scaled)
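
The change does two things: `reduction='none'` is the string that `BCEWithLogitsLoss` expects for an unreduced, per-element loss (passing `reduction=None` is rejected when the loss is evaluated), and the matrix product over a separately built weight tensor is replaced by a plain element-wise multiplication with the weights passed in through `kwargs`. A minimal, self-contained sketch of what the new `forward` computes, with made-up shapes and weight values (plain PyTorch, not code from the repository):

import torch

batch_size, n_labels = 4, 3
logits = torch.randn(batch_size, n_labels)                    # model output ("input")
targets = torch.randint(0, 2, (batch_size, n_labels)).float()

# hypothetical per-sample boosting weights, repeated across the label dimension
weights = torch.tensor([1.0, 0.5, 2.0, 1.0]).unsqueeze(1).expand(batch_size, n_labels)

bce = torch.nn.BCEWithLogitsLoss(reduction='none')            # as in BCE_Boosting.__init__
loss = bce(logits, targets)                                   # shape (batch_size, n_labels)
loss_scaled = loss * weights                                  # element-wise, as in forward()
print(torch.mean(loss_scaled))                                # single scalar is returned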

chebai/models/base.py

Lines changed: 7 additions & 0 deletions
@@ -7,6 +7,11 @@

 from chebai.preprocessing.structures import XYData

+import sys
+sys.path.insert(1,'/home/programmer/Bachelorarbeit/python-chebai')
+
+import extras.weight_loader as f
+
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)

 _MODEL_REGISTRY = dict()
@@ -264,6 +269,8 @@ def _execute(
             loss_kwargs = dict()
             if self.pass_loss_kwargs:
                 loss_kwargs = loss_kwargs_candidates
+                dict_weights = f.get_weights(data['idents'])
+                loss_kwargs['weights'] = f.create_data_weights(len(data['idents']),data['labels'].size(dim=1),dict_weights,data['idents'])
                 loss_kwargs["current_epoch"] = self.trainer.current_epoch
             loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
             if isinstance(loss, tuple):
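
The two added lines build the boosting weights for each batch: `f.get_weights` looks up a weight for every identifier in `data['idents']`, and `f.create_data_weights` expands those per-sample weights to the shape of the label matrix before they are handed to the criterion via `loss_kwargs['weights']`. A rough, self-contained sketch of that contract (the identifiers come from the commented-out smoke test in `extras/weight_loader.py`; label count and weight values are made up):

import torch

# stand-ins for data['idents'] and data['labels'] inside _execute
idents = (233713, 51990)
labels = torch.zeros(len(idents), 4)                 # (batch_size, n_labels)

# stand-in for f.get_weights(idents): maps str(ident) -> boosting weight
dict_weights = {"233713": 1.25, "51990": 0.8}

# stand-in for f.create_data_weights(...): each sample's weight repeated over
# all labels, so the result has the same shape as the unreduced BCE loss
weights = torch.tensor(
    [[float(dict_weights[str(i)])] * labels.size(dim=1) for i in idents]
)

loss_kwargs = {"weights": weights, "current_epoch": 0}
# loss = self.criterion(loss_data, loss_labels, **loss_kwargs)   # as in _execute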

configs/loss/boost_bce.yml

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+class_path: chebai.loss.boost_bce.BCE_Boosting
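
This new one-line config points the training setup at the new loss class. For illustration only, this is how a `class_path` entry of this form is conventionally resolved; the actual wiring in chebai presumably goes through its Lightning/jsonargparse CLI config handling, and the helper below is hypothetical:

import importlib

def instantiate_from_class_path(class_path: str, **init_kwargs):
    """Resolve a dotted 'module.ClassName' string and instantiate the class."""
    module_name, class_name = class_path.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**init_kwargs)

# criterion = instantiate_from_class_path("chebai.loss.boost_bce.BCE_Boosting")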

extras/weight_loader.py

Lines changed: 15 additions & 6 deletions
@@ -31,7 +31,7 @@ def mock_init_weights(path="/home/programmer/Bachelorarbeit/weights/first_it.csv

 #check the size of a csv file given a filter for the second object
 # assumes csv file has a header
-def get_size(path="/home/programmer/Bachelorarbeit/split/splits.csv",filter=["train","validation"]) -> int:
+def get_size(path="/home/programmer/Bachelorarbeit/split/splits.csv",filter=["train"]) -> int:
     with open(path,'r') as file:
         reader = csv.reader(file)
         size = -1
@@ -57,18 +57,27 @@ def find_weight(path:str,ident:int)-> float:
             if row[0] == str(ident):
                 return float(row[2])

-
-    print(f"{ident} is not in file ")
+    label = find_label(id=ident)
+    print(f"{ident} is not in file with {label} ")
+
+def find_label(id:int,path="/home/programmer/Bachelorarbeit/split/splits.csv")-> str:
+    with open(path,'r') as file:
+        reader = csv.reader(file)
+        for row in reader:
+            if row[0] == str(id):
+                return row[1]
+

 #to do
 # return should be a tuple of weigths matching the sequenece of the target and label tensor
 def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tuple[int,...])-> torch.tensor:
-    weight = torch.empty(batchsize,dim)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    weight = torch.empty(batchsize,dim,device=device)
     index = 0
     for i in idents:
         w = weights[str(i)]
         for j in range(0,dim):
-            weight[index][j] = w
+            weight[index][j] = float(w)
         index = index + 1
     return weight

@@ -90,6 +99,6 @@ def create_weight_tensor(weight:float)-> torch.tensor:
 def create_class_weights()-> torch.tensor:
     pass

-# mock_init_weights()
+mock_init_weights()
 # print(get_weights((233713,51990)))
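
Beyond the device handling and the explicit `float` cast, the nested loops in `create_data_weights` simply repeat one weight per sample across all label columns. An equivalent broadcast-based construction, shown only to make the intended `(batchsize, dim)` shape explicit (not part of the commit; weight values are made up, identifiers taken from the commented-out smoke test):

import torch

def create_data_weights_vectorized(batchsize, dim, weights, idents):
    # same result as the loop version above: row i holds weights[str(idents[i])]
    # repeated dim times
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    per_sample = torch.tensor([float(weights[str(i)]) for i in idents], device=device)
    return per_sample.unsqueeze(1).expand(batchsize, dim)

w = create_data_weights_vectorized(2, 3, {"233713": 1.25, "51990": 0.8}, (233713, 51990))
print(w)   # two rows: 1.25 and 0.80 repeated across the three label columns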
