## Proposed API

```julia
(y, st), train_state = train_state(x)
```

## Benefits

1. This makes it very simple to run inference. Simply call `train_state(x)` and we take care of `Lux.testmode(st)`, etc.
2. Allows us to cache the compiled inference function once #969 lands.
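A minimal sketch of how the proposed call could behave, written against a hypothetical stand-in struct `MyTrainState` rather than the real `Lux.Training.TrainState`, to avoid overloading a type we don't own. The field names (`model`, `parameters`, `states`) are assumptions for illustration. The functor switches the states to inference mode with `Lux.testmode` before the forward pass. It returns the train state unchanged; in the proposal, that returned value is where a cached compiled inference function could eventually live.

```julia
using Lux, Random

# Hypothetical stand-in for a train state; field names are assumed for illustration.
struct MyTrainState{M,P,S}
    model::M
    parameters::P
    states::S
end

# Make the train state callable: run inference on `x`.
function (ts::MyTrainState)(x)
    # Put layers such as Dropout/BatchNorm into inference mode.
    st_test = Lux.testmode(ts.states)
    y, st_new = ts.model(x, ts.parameters, st_test)
    # Returned unchanged here; a real implementation might store a cached
    # compiled inference function in the returned train state (assumption).
    return (y, st_new), ts
end

# Usage
rng = Random.default_rng()
model = Chain(Dense(2 => 4, tanh), Dropout(0.5), Dense(4 => 1))
ps, st = Lux.setup(rng, model)
train_state = MyTrainState(model, ps, st)

x = rand(Float32, 2, 8)
(y, st_out), train_state = train_state(x)
```

Because `Dropout` is a no-op in test mode, repeated calls on the same input give identical outputs, which is the behavior users expect from an inference call.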