Skip to content

Commit eb86e3f

Browse files
committed
cast to model device
1 parent 4f506b1 commit eb86e3f

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

chebai/result/generate_class_properties.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,18 @@ def generate_props(
172172
class_names = self.load_class_labels(classes_file)
173173
num_classes = len(class_names)
174174
metrics_obj_dict: dict[str, torchmetrics.Metric] = {
175-
"cm": MultilabelConfusionMatrix(num_labels=num_classes),
176-
"f1": MultilabelF1Score(num_labels=num_classes, average=None),
175+
"cm": MultilabelConfusionMatrix(num_labels=num_classes).to(
176+
device=model.device
177+
),
178+
"f1": MultilabelF1Score(num_labels=num_classes, average=None).to(
179+
device=model.device
180+
),
177181
}
178182

179183
for batch_idx, batch in enumerate(data_loader):
180184
data = model._process_batch(batch, batch_idx=batch_idx)
181-
labels = data["labels"]
185+
labels = data["labels"].to(device=model.device)
186+
data["features"][0].to(device=model.device)
182187
model_output = model(data, **data.get("model_kwargs", {}))
183188
preds, targets = model._get_prediction_and_labels(
184189
data, labels, model_output
@@ -241,7 +246,8 @@ def generate(
241246

242247

243248
if __name__ == "__main__":
244-
# _generate_classes_props_json.py generate \
249+
# Usage:
250+
# generate_classes_properties.py generate \
245251
# --data_partition "val" \
246252
# --model_ckpt_path "model/ckpt/path" \
247253
# --model_config_file_path "model/config/file/path" \

0 commit comments

Comments
 (0)