
Commit eeb62d1

Author: Stefan Dumitrescu
Added a torch.no_grad() context manager so the model won't compute gradients at inference time.
This should speed up inference and roughly halve RAM usage; memory now seems stable at around 4-5 GB.
1 parent: 8b81370
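
For background, a minimal standalone sketch (illustrative, not the project's code) of what torch.no_grad() changes: inside the context manager, autograd records no computation graph, so intermediate activations can be freed immediately and outputs come back with requires_grad=False.

    import torch
    import torch.nn as nn

    # Toy stand-in for the project's NNet; any module behaves the same way.
    model = nn.Linear(1024, 1024)
    model.eval()

    x = torch.randn(8, 1024)

    y = model(x)                # autograd tracks this forward pass
    print(y.requires_grad)      # True: a graph was retained for backprop

    with torch.no_grad():       # inference-only: no graph, less memory
        y = model(x)
    print(y.requires_grad)      # False: nothing retained for backward

Note that model.eval() and torch.no_grad() are complementary: eval() switches layer behavior (dropout, batch norm), while no_grad() disables gradient tracking; the memory saving described in the commit message comes from the latter.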

File tree

1 file changed (+13 lines, -12 lines)

fred/inference/predict.py

Lines changed: 13 additions & 12 deletions
@@ -47,18 +47,19 @@ def predict(image_file):
     model = NNet()
     model.load_state_dict(torch.load(model_path, map_location='cpu'))
     model.eval()
-
-    pilim = Image.open(image_file).convert('L').convert('RGB')
-    pilim = preprocess_pilim(pilim)
-    input_array = prepare_for_input(pilim, flip_lr=False)
-
-    lr_input_array = prepare_for_input(pilim, flip_lr=True)
-    try:
-        out_array = get_output(model(get_tensor(input_array)))
-    except:
-        exit(2)
-
-    lr_out_array = np.fliplr(get_output(model(get_tensor(lr_input_array))))
+
+    with torch.no_grad():
+        pilim = Image.open(image_file).convert('L').convert('RGB')
+        pilim = preprocess_pilim(pilim)
+        input_array = prepare_for_input(pilim, flip_lr=False)
+
+        lr_input_array = prepare_for_input(pilim, flip_lr=True)
+        try:
+            out_array = get_output(model(get_tensor(input_array)))
+        except:
+            exit(2)
+
+        lr_out_array = np.fliplr(get_output(model(get_tensor(lr_input_array))))
 
     out_array = (out_array + lr_out_array) / 2
     out_array = threshold_output(out_array, 0.5)
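
The lines the hunk leaves unchanged average two predictions: one from the original image and one from a left-right mirrored copy whose output is flipped back with np.fliplr, a simple test-time augmentation. A hedged sketch of that idea (names are illustrative, not from predict.py):

    import numpy as np

    def flip_tta(forward, image):
        # Predict on the original and on its horizontal mirror,
        # un-mirror the second prediction, and average the two.
        out = forward(image)
        out_lr = np.fliplr(forward(np.fliplr(image)))
        return (out + out_lr) / 2

With this commit, both forward passes run inside the no_grad() block, so each benefits from the reduced memory use.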
