mx.eval(l, a) array without primitive error #2914
aimaregana started this conversation in General
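When you compile, it's important to capture any implicit inputs and outputs of the function (here, the model parameters and the optimizer state):

```python
from functools import partial

import mlx.core as mx

# model, optimizer and loss_and_grad_fn come from the existing training script.
# Everything the step reads or mutates implicitly must be declared:
state = [model, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    (loss, acc), grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)  # mutates model and optimizer.state
    return loss, acc
```

And below:

```python
loss, acc = train_step(x, y)
mx.eval(loss, acc, state)  # evaluate the outputs and the updated state together
```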
Hello everyone,
I'm trying to move from PyTorch to MLX because I'm using a Mac mini M4 Pro, where MLX is known to be faster.
I ran a speed comparison with the "same code" in PyTorch and MLX and got:
- about 300 img/s with PyTorch
- about 850 img/s with MLX
My main problem is that `mx.eval(l, a)` raises an "attempting to eval without primitive" error after the first epoch.
As you can see, the first epoch runs at 877 img/s and then the error appears.
Could anyone help me solve this problem?
Best regards, and thanks in advance!
If I remove `@mx.compile` from `train_step` and `eval_step`, everything works, but the speed drops to 130 img/s.
My code:
kaggle MLX 2.py