@@ -172,13 +172,18 @@ def generate_props(
172172 class_names = self .load_class_labels (classes_file )
173173 num_classes = len (class_names )
174174 metrics_obj_dict : dict [str , torchmetrics .Metric ] = {
175- "cm" : MultilabelConfusionMatrix (num_labels = num_classes ),
176- "f1" : MultilabelF1Score (num_labels = num_classes , average = None ),
175+ "cm" : MultilabelConfusionMatrix (num_labels = num_classes ).to (
176+ device = model .device
177+ ),
178+ "f1" : MultilabelF1Score (num_labels = num_classes , average = None ).to (
179+ device = model .device
180+ ),
177181 }
178182
179183 for batch_idx , batch in enumerate (data_loader ):
180184 data = model ._process_batch (batch , batch_idx = batch_idx )
181- labels = data ["labels" ]
185+ labels = data ["labels" ].to (device = model .device )
186+ data ["features" ][0 ].to (device = model .device )
182187 model_output = model (data , ** data .get ("model_kwargs" , {}))
183188 preds , targets = model ._get_prediction_and_labels (
184189 data , labels , model_output
@@ -241,7 +246,8 @@ def generate(
241246
242247
243248if __name__ == "__main__" :
244- # _generate_classes_props_json.py generate \
249+ # Usage:
250+ # generate_classes_properties.py generate \
245251 # --data_partition "val" \
246252 # --model_ckpt_path "model/ckpt/path" \
247253 # --model_config_file_path "model/config/file/path" \
0 commit comments