I am trying to use InfoBatch for an image retrieval task. The loss I compute is a single scalar tensor, `tensor(1.5947, device='cuda:0', grad_fn=<...>)`, whose shape is empty (`torch.Size([])`), so when I call `loss = train_dataset.update(loss)` I get an error. How can I solve this problem?
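A likely cause is that the loss has already been averaged over the batch. InfoBatch's usage examples appear to build the criterion with `reduction='none'` before calling `update`, so that it receives one loss value per sample rather than a single scalar. The sketch below assumes a triplet loss (`nn.TripletMarginLoss`) as the retrieval objective, with random embeddings in place of a real model. `hypothetical_update` is a made-up stand-in for `train_dataset.update`, not InfoBatch's actual implementation; it only checks for a per-sample loss and returns a scalar that can be backpropagated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Random embeddings stand in for a real retrieval model's outputs (assumption).
batch_size, dim = 8, 16
anchor = torch.randn(batch_size, dim, requires_grad=True)
positive = torch.randn(batch_size, dim)
negative = torch.randn(batch_size, dim)

# Default reduction='mean' collapses the batch into a 0-d tensor (shape torch.Size([])).
scalar_loss = nn.TripletMarginLoss(margin=1.0)(anchor, positive, negative)

# reduction='none' keeps one loss value per sample (shape (batch_size,)).
criterion = nn.TripletMarginLoss(margin=1.0, reduction="none")
per_sample_loss = criterion(anchor, positive, negative)


def hypothetical_update(per_sample: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for train_dataset.update(loss); NOT InfoBatch's code.

    Assumes the real method needs one loss per sample, may reweight them,
    and returns a scalar. Here the weights are all 1, so this is a plain mean.
    """
    assert per_sample.dim() == 1, "expected one loss value per sample"
    weights = torch.ones_like(per_sample)
    return (per_sample * weights).mean()


loss = hypothetical_update(per_sample_loss)
loss.backward()  # the returned scalar is what gets backpropagated
```

For an InfoNCE-style contrastive loss, `torch.nn.functional.cross_entropy(logits, labels, reduction="none")` gives per-sample losses in the same way.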

