NN output within a numba jitted function #8099
Hello, I have a jitted function within which I need to use the output of a neural network (trained using PyTorch Lightning). This pseudocode makes the problem clearer:

```python
while True:
    x = sample_from_model()       # numpy type, hence compatible with numba
    out = NN(torch.Tensor(x))     # torch call, incompatible with numba
```

Is there a way to circumvent this problem? The first thing that comes to mind is to manually extract the weights and compute the forward pass myself. Thanks in advance!
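The "extract the weights manually" idea from the question can be sketched as follows. This is a minimal illustration, assuming a small 2-layer MLP with ReLU; the weight arrays here are random stand-ins for values you would pull out of the trained model (e.g. via `{k: v.detach().numpy() for k, v in model.state_dict().items()}`), and `forward` is a hypothetical helper name:

```python
import numpy as np

# Stand-in weights for a hypothetical 4 -> 8 -> 2 MLP with ReLU.
# In practice these would be extracted from the trained PyTorch model.
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(4, 8)), np.zeros(8)
W2, b2 = rng.normal(size=(8, 2)), np.zeros(2)

def forward(x):
    """Forward pass using only NumPy ops, so numba can compile it
    (decorate with @numba.njit once numba is available)."""
    h = np.maximum(x @ W1 + b1, 0.0)  # linear layer + ReLU
    return h @ W2 + b2                # output layer

x = rng.normal(size=4)
out = forward(x)
print(out.shape)  # (2,)
```

Because the body uses only NumPy array operations, the same function compiles under `@numba.njit` and can be called from inside the jitted loop.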
Hi Petar,

I'm not that familiar with numba, but since it works with numpy types, you should be able to solve this via ONNX export. You can get the export simply with

```python
my_lightning_model.to_onnx("model.onnx", input_sample)
```

Note: this dumps the model to disk; you can then use ONNX Runtime for prediction, which consumes and returns plain numpy arrays.