I wonder if this is related to this issue.

I find that simply calling the jitted _body from a plain Python for-loop improves the performance (2.18s):

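# Assumes np is jax.numpy and that x_train, y_train, N, learning_rate,
# and num_epochs are already defined, as in the original question.
from jax import jit

@jit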
def _body(_, val):
  # Forward pass [NX1] · [1X1] = [NX1]
  W, b = val

  # Loss
  error = y_train - (x_train @ W + b)
  loss = (error.T @ error) / N

  # Backpropagation
  dW = -(2/N) * (x_train.T @ error)
  db = -(2/N) * np.sum(error)

  # Update weights
  W += -learning_rate * dW
  b += -learning_rate * db

  return (W, b)

for _ in range(num_epochs):
  W, b = _body(_, (W, b))
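
For anyone who wants to reproduce the comparison, here is a minimal self-contained sketch. The data, learning_rate, and num_epochs are made-up placeholders (the original x_train / y_train are not shown in this thread), and init is a hypothetical helper. It drives the same jitted step two ways: with a Python for-loop as above, and with jax.lax.fori_loop. It only checks that both give the same parameters and makes no claim about which is faster.

```python
import jax.numpy as jnp
from jax import jit, lax

# Hypothetical synthetic data standing in for the original x_train / y_train.
N = 100
x_train = jnp.linspace(0.0, 1.0, N).reshape(N, 1)
y_train = 3.0 * x_train + 1.0
learning_rate = 0.1   # assumed value
num_epochs = 200      # assumed value


@jit
def _body(_, val):
    W, b = val
    # Residual of the linear model: [N,1] - ([N,1] @ [1,1] + scalar)
    error = y_train - (x_train @ W + b)
    # Gradients of the mean squared error with respect to W and b
    dW = -(2 / N) * (x_train.T @ error)
    db = -(2 / N) * jnp.sum(error)
    # Gradient-descent update, returned as a new (W, b) pair
    return (W - learning_rate * dW, b - learning_rate * db)


def init():
    # Hypothetical helper: fresh zero-initialised parameters
    return (jnp.zeros((1, 1)), jnp.zeros(()))


# Variant 1: Python for-loop over the jitted step
W_py, b_py = init()
for i in range(num_epochs):
    W_py, b_py = _body(i, (W_py, b_py))

# Variant 2: the same step driven by lax.fori_loop
W_fl, b_fl = lax.fori_loop(0, num_epochs, _body, init())

# JAX dispatches asynchronously, so wait for the results before timing anything
W_py.block_until_ready()
W_fl.block_until_ready()
```

If you time either variant, wrap the timer around the loop plus the block_until_ready() call, and run it once beforehand so compilation time is not included.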
