Skip to content

Commit fa1f0b9

Browse files
committed
solves the error of using the collate.py for the eval of the model
1 parent 3fdcf59 commit fa1f0b9

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

chebai/preprocessing/collate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
105105
lens = torch.tensor(list(map(len, x)))
106106
model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None]
107107
model_kwargs["lens"] = lens
108-
for d in data:
109-
id = d["ident"]
110-
weight = d["weight"]
111-
loss_kwargs[str(id)] = weight
108+
if "weight" in data[0]:
109+
for d in data:
110+
id = d["ident"]
111+
weight = d["weight"]
112+
loss_kwargs[str(id)] = weight
112113

113114

114115
return XYData(

0 commit comments

Comments
 (0)