-
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathjax_nutritional_content_prediction_custom_transform.py
More file actions
108 lines (87 loc) · 3.56 KB
/
jax_nutritional_content_prediction_custom_transform.py
File metadata and controls
108 lines (87 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# jax_nutritional_content_prediction_custom_transform.py
# Import necessary libraries
import jax
import jax.numpy as jnp
from jax import grad, jit, random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu
import numpy as np
# Simulate data
np.random.seed(42)
X_train = np.random.rand(100, 8) # 100 samples, 8 features
y_train = np.random.rand(100, 4) # 100 samples, 4 targets (calories, proteins, carbs, fats)
X_test = np.random.rand(20, 8) # 20 samples, 8 features
y_test = np.random.rand(20, 4) # 20 samples, 4 targets
# Define the neural network architecture
def create_model(input_shape):
return stax.serial(
Dense(128), Relu,
Dense(64), Relu,
Dense(4) # Output layer for 4 targets
)
# Create the model
init_fun, apply_fun = create_model(X_train.shape[1])
# Initialize model parameters
key = random.PRNGKey(42)
_, init_params = init_fun(key, (-1, X_train.shape[1]))
# Define the loss function
def loss_fn(params, inputs, targets):
predictions = apply_fun(params, inputs)
return jnp.mean((predictions - targets) ** 2)
# Define the update function
@jit
def update(params, inputs, targets, learning_rate):
grads = grad(loss_fn)(params, inputs, targets)
return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
# Custom transformation for batch normalization
def custom_batch_norm(x, mean, var, beta, gamma, eps=1e-5):
return (x - mean) / jnp.sqrt(var + eps) * gamma + beta
# Training settings
num_epochs = 100
batch_size = 32
learning_rate = 0.001
# Train the model
for epoch in range(num_epochs):
# Shuffle the training data
perm = np.random.permutation(len(X_train))
X_train = X_train[perm]
y_train = y_train[perm]
# Mini-batch training
epoch_loss = 0.0
num_batches = 0
for i in range(0, len(X_train), batch_size):
X_batch = X_train[i:i+batch_size]
y_batch = y_train[i:i+batch_size]
# Apply custom batch normalization
mean = jnp.mean(X_batch, axis=0)
var = jnp.var(X_batch, axis=0)
beta = jnp.zeros_like(mean)
gamma = jnp.ones_like(var)
X_batch = custom_batch_norm(X_batch, mean, var, beta, gamma)
# Update parameters
init_params = update(init_params, X_batch, y_batch, learning_rate)
batch_loss = loss_fn(init_params, X_batch, y_batch)
epoch_loss += batch_loss
num_batches += 1
# Calculate average loss for the epoch
epoch_loss /= num_batches
print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}")
# Evaluate the model
test_predictions = apply_fun(init_params, X_test)
test_loss = jnp.mean((test_predictions - y_test) ** 2)
print(f"Test Loss: {test_loss:.4f}")
# Define the inference function
@jit
def predict(params, inputs):
return apply_fun(params, inputs)
# Example usage
new_input = np.array([[5, 3, 2, 1, 4, 2, 0, 1]]) # Replace with actual input features
predicted_nutrition = predict(init_params, new_input)
print(f"Predicted Nutrition: {predicted_nutrition}")
# Possible Errors and Solutions:
# 1. ValueError: If the input data contains NaN or infinite values, it may cause errors in training.
# Solution: Ensure the input data is cleaned and contains no NaN or infinite values.
# 2. Dimension Mismatch: If the dimensions of input features or targets do not match the expected dimensions, an error will occur.
# Solution: Verify the dimensions of your input data and targets before training.
# 3. Convergence Issues: If the learning rate is too high, the model may not converge.
# Solution: Decrease the learning rate and try again.