Skip to content

Commit 4c76fec

Browse files
committed
sync commit
1 parent 7721d8c commit 4c76fec

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

chebai/loss/boost_bce.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
import sys
3+
sys.path.insert(1,'/home/programmer/Bachelorarbeit/python-chebai')
4+
5+
import extras.weight_loader as f
6+
7+
8+
class BCE_Point_boosting(torch.nn.BCEWithLogitsLoss):
    """BCE-with-logits loss whose element-wise losses are scaled by
    externally supplied per-class weights before being averaged.

    The parent loss is configured with ``reduction='none'`` so the raw,
    unreduced loss matrix is available for re-weighting in ``forward``.
    """

    def __init__(self, **kwargs):
        # BUG FIX: PyTorch expects the *string* 'none' to disable
        # reduction; passing the value None raises a ValueError the
        # first time forward() is called. We need the unreduced,
        # element-wise loss so it can be re-weighted below.
        super().__init__(reduction="none", **kwargs)

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the weight-boosted mean BCE loss.

        Args:
            input: raw logits.
            target: binary targets, same shape as ``input``.
            **kwargs: must contain ``'weights'`` (raises KeyError otherwise),
                which is forwarded to ``extras.weight_loader.create_weight_tensor``.

        Returns:
            Scalar tensor: mean of the weight-scaled element-wise losses.
        """
        weights = kwargs["weights"]
        # Unreduced element-wise BCE loss (shape of `input`).
        loss = super().forward(input=input, target=target)
        # NOTE(review): create_weight_tensor is still a stub in this commit;
        # presumably it returns a (1, n_classes) tensor — confirm. Also
        # verify the matmul dimensions: (1, n) @ (batch, n) only works
        # when batch == n; an element-wise product may be intended.
        weights_tensor = f.create_weight_tensor(weights)
        loss_scaled = torch.matmul(weights_tensor, loss)
        return torch.mean(loss_scaled)

extras/weight_loader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def find_weight(path:str,ident:int)-> float:
6060

6161
print(f"{ident} is not in file ")
6262

63-
63+
#to do
64+
# return should be a tuple of weights matching the sequence of the target and label tensor
6465
def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tuple[int,...])-> torch.tensor:
6566
weight = torch.empty(batchsize,dim)
6667
index = 0
@@ -76,6 +77,9 @@ def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tup
7677
def testing():
    """Smoke-test hook: emit a fixed greeting on stdout."""
    message = "hello world"
    print(message)
7879

80+
# create a tensor that is size (1, n) where n is the amount of classes being predicted
# NOTE(review): stub — currently returns None. Callers (BCE_Point_boosting.forward)
# pass its result to torch.matmul, so a real tensor is required. TODO: implement.
def create_weight_tensor(weight: float) -> torch.Tensor:
    """Stub: build a (1, n) weight tensor from *weight*. Not yet implemented."""
    pass
7983

8084

8185

0 commit comments

Comments
 (0)