Skip to content

Commit e62fa84

Browse files
committed
restore the original code
1 parent 6d2d220 commit e62fa84

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

chebai/models/base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,13 @@ def _process_batch(self, batch: XYData, batch_idx: int) -> Dict[str, Any]:
108108
Returns:
109109
Dict[str, Any]: Processed batch data.
110110
"""
111-
return {
112-
"features": batch.x,
113-
"labels": self._process_labels_in_batch(batch),
114-
"model_kwargs": batch.additional_fields.get("model_kwargs", {}),
115-
"loss_kwargs": batch.additional_fields.get("loss_kwargs", {}),
116-
"idents": batch.additional_fields.get("idents", []),
117-
}
111+
return dict(
112+
features=batch.x,
113+
labels=self._process_labels_in_batch(batch),
114+
model_kwargs=batch.additional_fields["model_kwargs"],
115+
loss_kwargs=batch.additional_fields["loss_kwargs"],
116+
idents=batch.additional_fields["idents"],
117+
)
118118

119119
def _process_for_loss(
120120
self,

0 commit comments

Comments
 (0)