section 02: I am unable to understand plot_decision_boundry() function #567
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    def plot_decision_boundry(model, X, y):
        # Move everything to the CPU so it works with NumPy and Matplotlib
        model.to("cpu")
        X, y = X.to("cpu"), y.to("cpu")

        # Build a 101 x 101 grid spanning the range of the two features
        x_min, x_max = X[:, 0].min(), X[:, 0].max()
        y_min, y_max = X[:, 1].min(), X[:, 1].max()
        xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101),
                             np.linspace(y_min, y_max, 101))

        # Make features: one (x, y) row per grid point
        X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()

        # Make pred
        model.eval()
        with torch.inference_mode():
            y_logits = model(X_to_pred_on)

        # Test for multi-class or binary and adjust logits to prediction labels
        if len(torch.unique(y)) > 2:
            y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)  # multi-class
        else:
            y_pred = torch.round(torch.sigmoid(y_logits))  # binary

        # Reshape preds to the grid shape and plot
        y_pred = y_pred.reshape(xx.shape).detach().numpy()
        plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
        plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
        plt.xlim(xx.min(), xx.max())
        plt.ylim(yy.min(), yy.max())

Can anyone please help me with this code?
Answered by
Perian-Yan
Jul 24, 2023
Answer selected by
AS1100K
Hi,
the key is to get the coordinates of each grid point (here there are 101*101 of them) and treat them as a "new dataset", X_to_pred_on. Then, predict the label (0 or 1) for each of these grid points. Finally, plt.contourf plots the predictions as filled colored regions, which also shows the boundary between them.
By the way, you can always print the intermediate results to see what's going on.
You can refer to meshgrid, ravel, column_stack, and contourf.
Hope this helps 😄
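To make the "print the intermediate results" advice concrete, here is a minimal sketch of the grid steps on a tiny 2 x 3 grid instead of 101 x 101. It uses only NumPy. The function toy_predict is a hypothetical stand-in for the trained model (it is not part of the original code), chosen only so the example runs without PyTorch.

```python
import numpy as np

# A tiny grid so the intermediate arrays are easy to read
# (the original function uses 101 points per axis).
xs = np.linspace(0.0, 1.0, 3)  # x-coordinates: 0.0, 0.5, 1.0
ys = np.linspace(0.0, 1.0, 2)  # y-coordinates: 0.0, 1.0

# meshgrid returns two 2-D arrays: xx holds the x-coordinate of every
# grid point, yy holds the y-coordinate of every grid point.
xx, yy = np.meshgrid(xs, ys)
print("xx:\n", xx)
print("yy:\n", yy)

# ravel flattens each 2-D array to 1-D; column_stack pairs them up so
# that each row is one (x, y) grid point, i.e. the "new dataset".
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
print("grid_points shape:", grid_points.shape)
print(grid_points)


def toy_predict(points):
    """Hypothetical stand-in for the model: label 1 where x + y > 1."""
    return (points[:, 0] + points[:, 1] > 1.0).astype(int)


# Predict one label per grid point, then reshape back to the grid shape
# so each label lines up with its (xx, yy) position.
labels = toy_predict(grid_points)
label_grid = labels.reshape(xx.shape)
print("label_grid:\n", label_grid)
```

An array shaped like label_grid is what gets passed to plt.contourf(xx, yy, ...) in the original function: contourf fills each region of equal label with one color, and the edge between the colors is the decision boundary.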