Skip to content

Commit 8d1f7d0

Browse files
committed
efficient weight loading
1 parent c745dfe commit 8d1f7d0

File tree

5 files changed

+25
-17
lines changed

5 files changed

+25
-17
lines changed

chebai/models/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,13 +261,15 @@ def _execute(
261261
model_output = self(data, **data.get("model_kwargs", dict()))
262262
pr, tar = self._get_prediction_and_labels(data, labels, model_output)
263263
d = dict(data=data, labels=labels, output=model_output, preds=pr)
264-
torch.save(d,"d.pt")
264+
# torch.save(data["loss_kwargs"],"kloss.pt")
265+
# torch.save(data["idents"],"id.pt")
265266
if log:
266267
if self.criterion is not None:
267268
loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss(
268269
model_output, labels, data.get("loss_kwargs", dict())
269270
)
270271
loss_kwargs = dict()
272+
loss_kwargs['weights'] = f.create_data_weights(batchsize=len(data['idents']),dim=data['labels'].size(dim=1),weights=data["loss_kwargs"],idents=data["idents"])
271273
if self.pass_loss_kwargs:
272274
loss_kwargs = loss_kwargs_candidates
273275
torch.save(loss_data,"loss_data.pt")

chebai/preprocessing/bin/smiles_token/tokens.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,3 +984,5 @@ p
984984
[ClH2+]
985985
[BrH2+]
986986
[IH2+]
987+
[NH3]
988+
[OH2]

chebai/preprocessing/collate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
108108
for d in data:
109109
id = d["ident"]
110110
weight = d["weight"]
111-
loss_kwargs["ident"] = weight
111+
loss_kwargs[str(id)] = weight
112112

113113

114114
return XYData(

chebai/preprocessing/datasets/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,11 +1184,11 @@ def load_processed_data(
11841184
data_df = self.dynamic_split_dfs[kind]
11851185
data = data_df.to_dict(orient="records")
11861186
if kind == "train":
1187-
data = f.add_train_weights(data)
1187+
f.init_weights()
1188+
data = f.add_val_weights(data)
11881189
if kind == "validation":
1189-
print(kind)
11901190
data = f.add_val_weights(data)
1191-
torch.save(data,"gewicht.pt")
1191+
# torch.save(data,"gewicht.pt")
11921192

11931193
return data
11941194

extras/weight_loader.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,16 @@ def find_label(id:int,path="../split/splits.csv")-> str:
7171

7272
#to do
7373
# return should be a tuple of weigths matching the sequenece of the target and label tensor
74-
# def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tuple[int,...])-> torch.tensor:
75-
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
76-
# weight = torch.empty(batchsize,dim,device=device)
77-
# index = 0
78-
# for i in idents:
79-
# w = weights[str(i)]
80-
# for j in range(0,dim):
81-
# weight[index][j] = float(w)
82-
# index = index + 1
83-
# return weight
74+
def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tuple[int,...])-> torch.tensor:
75+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
76+
weight = torch.empty(batchsize,dim,device=device)
77+
index = 0
78+
for i in idents:
79+
w = weights[str(i)]
80+
for j in range(0,dim):
81+
weight[index][j] = float(w)
82+
index = index + 1
83+
return weight
8484

8585
def testing():
8686
print("hello world")
@@ -97,17 +97,21 @@ def add_val_weights(ids):
9797

9898

9999
def add_train_weights(ids):
100+
it = 0
100101
for i in ids:
102+
if it % 10000 == 0:
103+
print(it)
101104
ident = i["ident"]
102105
weight = find_weight("/home/programmer/Bachelorarbeit/weights/first_it.csv",ident=ident)
103106
i["weight"] = weight
107+
it = it +1
104108
return ids
105109

106110
def check_weights(data):
107111
for i in data:
108112
print(f"({i["ident"]} , {i["weight"]}")
109113

110-
111-
#mock_init_weights()
114+
init_weights()
115+
mock_init_weights()
112116
# print(get_weights((233713,51990)))
113117

0 commit comments

Comments
 (0)