diff --git a/examples/imagenet_logits.py b/examples/imagenet_logits.py index ef2cbf19..fe05cb9a 100644 --- a/examples/imagenet_logits.py +++ b/examples/imagenet_logits.py @@ -62,7 +62,7 @@ def main(): # Make predictions output = model(input) # size(1, 1000) max, argmax = output.data.squeeze().max(0) - class_id = argmax[0] + class_id = argmax.item() class_key = class_id_to_key[class_id] classname = key_to_classname[class_key]