Commit 84d5d78

Add LSTM (#247)
1 parent f697081 commit 84d5d78

File tree: 1 file changed (+349, −0)

# LSTM Time Series Prediction in R
# Long Short-Term Memory (LSTM) Neural Network for Time Series Forecasting
#
# Required libraries: keras, tensorflow, (optional) tidyr or reshape2, ggplot2
# Install with: install.packages("keras"); install.packages("tensorflow")
# Optionally: install.packages(c("tidyr", "reshape2", "ggplot2"))
# Then run: keras::install_keras()

suppressPackageStartupMessages({
  library(keras)
  library(tensorflow)
})

#' Create sequences for LSTM training
#' @param data Numeric vector of time series data
#' @param seq_length Length of input sequences
#' @return List containing X (input sequences) and y (target values)
create_sequences <- function(data, seq_length) {
  n <- length(data)

  # Initialize lists to store sequences
  X <- list()
  y <- list()

  # Sliding window: each input window of length seq_length is paired
  # with the value that immediately follows it
  for (i in 1:(n - seq_length)) {
    X[[i]] <- data[i:(i + seq_length - 1)]
    y[[i]] <- data[i + seq_length]
  }

  # Convert lists to arrays in the shape Keras expects:
  # X is (samples, timesteps, features), y is (samples, 1).
  # R fills arrays column-major, so build X row-wise first to keep
  # each window in its own sample slice.
  n_seq <- length(X)
  X <- array(matrix(unlist(X), nrow = n_seq, byrow = TRUE),
             dim = c(n_seq, seq_length, 1))
  y <- array(unlist(y), dim = c(n_seq, 1))

  return(list(X = X, y = y))
}

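# Minimal illustrative check of the windowing (not needed for the main
# example below): for the toy series 1..6 with seq_length = 3, the first
# input window is c(1, 2, 3) and its target is 4.
toy <- create_sequences(as.numeric(1:6), seq_length = 3)
stopifnot(all(toy$X[1, , 1] == c(1, 2, 3)), toy$y[1, 1] == 4)
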
#' Normalize data to [0, 1] range
#' @param data Numeric vector
#' @return List with normalized data, min, and max values
normalize_data <- function(data) {
  min_val <- min(data)
  max_val <- max(data)
  normalized <- (data - min_val) / (max_val - min_val)

  return(list(
    data = normalized,
    min = min_val,
    max = max_val
  ))
}

#' Inverse-normalize data back to the original scale
#' @param data Normalized data
#' @param min_val Original minimum value
#' @param max_val Original maximum value
#' @return Data in original scale
denormalize_data <- function(data, min_val, max_val) {
  return(data * (max_val - min_val) + min_val)
}

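# Illustrative sanity check of the scaling round trip:
# denormalize_data() should invert normalize_data() exactly.
scaled <- normalize_data(c(3, 7, 11))
stopifnot(all.equal(denormalize_data(scaled$data, scaled$min, scaled$max),
                    c(3, 7, 11)))
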
#' Build LSTM model for time series prediction
#' @param seq_length Length of input sequences
#' @param lstm_units Number of LSTM units (neurons)
#' @param dropout_rate Dropout rate for regularization (0 to 1)
#' @param learning_rate Learning rate for the optimizer
#' @return Compiled Keras model
build_lstm_model <- function(seq_length, lstm_units = 50,
                             dropout_rate = 0.2, learning_rate = 0.001) {

  model <- keras_model_sequential() %>%
    layer_lstm(units = lstm_units,
               activation = "tanh",
               input_shape = c(seq_length, 1),
               return_sequences = FALSE) %>%
    layer_dropout(rate = dropout_rate) %>%
    layer_dense(units = 1)  # Output layer for regression

  # Compile with Adam and MSE loss, tracking MAE as a secondary metric
  model %>% compile(
    optimizer = optimizer_adam(learning_rate = learning_rate),
    loss = "mean_squared_error",
    metrics = c("mae")
  )

  return(model)
}

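# Hedged sketch (defined but not used below): a two-layer variant for
# more complex patterns, as the best-practices notes at the end of this
# file suggest. The first LSTM must set return_sequences = TRUE so it
# emits the full sequence the second LSTM consumes. The function name is
# illustrative, not part of the original script.
build_stacked_lstm_model <- function(seq_length, lstm_units = 50,
                                     dropout_rate = 0.2, learning_rate = 0.001) {
  model <- keras_model_sequential() %>%
    layer_lstm(units = lstm_units, return_sequences = TRUE,
               input_shape = c(seq_length, 1)) %>%
    layer_dropout(rate = dropout_rate) %>%
    layer_lstm(units = lstm_units, return_sequences = FALSE) %>%
    layer_dropout(rate = dropout_rate) %>%
    layer_dense(units = 1)

  model %>% compile(
    optimizer = optimizer_adam(learning_rate = learning_rate),
    loss = "mean_squared_error",
    metrics = c("mae")
  )
  return(model)
}
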
#' Calculate evaluation metrics
#' @param actual Actual values
#' @param predicted Predicted values
#' @return List of evaluation metrics
calculate_metrics <- function(actual, predicted) {
  mse <- mean((actual - predicted)^2)
  rmse <- sqrt(mse)
  mae <- mean(abs(actual - predicted))

  # R-squared: 1 - (residual sum of squares / total sum of squares)
  ss_res <- sum((actual - predicted)^2)
  ss_tot <- sum((actual - mean(actual))^2)
  r_squared <- 1 - (ss_res / ss_tot)

  return(list(
    MSE = mse,
    RMSE = rmse,
    MAE = mae,
    R_squared = r_squared
  ))
}

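# Illustrative check of the metric formulas: for actual = c(1, 2, 3) and
# predicted = c(1, 2, 4), the errors are (0, 0, -1), so MSE = MAE = 1/3;
# ss_res = 1 and ss_tot = 2, giving R-squared = 0.5.
m_check <- calculate_metrics(c(1, 2, 3), c(1, 2, 4))
stopifnot(all.equal(m_check$MSE, 1 / 3), all.equal(m_check$R_squared, 0.5))
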
# ========== Main Example: Sine Wave Prediction ==========

cat("========== LSTM Time Series Prediction Example ==========\n\n")
cat("Generating synthetic sine wave data...\n")

# Set random seeds for reproducibility (both R and TensorFlow)
set.seed(42)
tensorflow::tf$random$set_seed(42)

# Generate sine wave data (500 points)
time_points <- seq(0, 20, length.out = 500)
data <- sin(time_points)

cat(sprintf("Generated %d data points\n\n", length(data)))

# Normalize data
cat("Normalizing data to [0, 1] range...\n")
normalized <- normalize_data(data)
data_norm <- normalized$data

# Create sequences for LSTM
seq_length <- 20
cat(sprintf("Creating sequences with length: %d\n", seq_length))
sequences <- create_sequences(data_norm, seq_length)

X <- sequences$X
y <- sequences$y

cat(sprintf("Total sequences created: %d\n\n", dim(X)[1]))

# Split into train and test sets (80/20 split)
train_size <- floor(0.8 * dim(X)[1])

# IMPORTANT: for time series, use a sequential split (preserve temporal order)
train_indices <- 1:train_size
test_indices <- (train_size + 1):dim(X)[1]

X_train <- X[train_indices, , , drop = FALSE]
y_train <- y[train_indices, , drop = FALSE]
X_test <- X[test_indices, , , drop = FALSE]
y_test <- y[test_indices, , drop = FALSE]

cat(sprintf("Training samples: %d\n", dim(X_train)[1]))
cat(sprintf("Test samples: %d\n\n", dim(X_test)[1]))

# Build LSTM model
cat("Building LSTM model...\n")
model <- build_lstm_model(
  seq_length = seq_length,
  lstm_units = 50,
  dropout_rate = 0.2,
  learning_rate = 0.001
)

cat("\nModel Architecture:\n")
summary(model)  # summary() already prints; no print() wrapper needed

# Train the model
cat("\n========== Training LSTM Model ==========\n\n")

history <- model %>% fit(
  X_train, y_train,
  epochs = 50,
  batch_size = 16,
  # validation_split holds out the LAST 10% of the training sequences,
  # which for an ordered series means the most recent windows
  validation_split = 0.1,
  verbose = 1
)

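# Hedged sketch (wrapped in `if (FALSE)` so it does not run here): for
# longer training runs, the best-practices notes below suggest early
# stopping and model checkpointing, which Keras exposes as callbacks.
# The epoch count and the filepath "best_lstm.h5" are illustrative.
if (FALSE) {
  history <- model %>% fit(
    X_train, y_train,
    epochs = 200,
    batch_size = 16,
    validation_split = 0.1,
    callbacks = list(
      callback_early_stopping(monitor = "val_loss", patience = 10,
                              restore_best_weights = TRUE),
      callback_model_checkpoint("best_lstm.h5", monitor = "val_loss",
                                save_best_only = TRUE)
    ),
    verbose = 1
  )
}
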
# Plot training history
if (requireNamespace("ggplot2", quietly = TRUE)) {
  cat("\nPlotting training history...\n")
  plot(history)
}

# Evaluate on test data
cat("\n========== Model Evaluation ==========\n\n")
evaluation <- model %>% evaluate(X_test, y_test, verbose = 0)
cat(sprintf("Test Loss (MSE): %.6f\n", evaluation[[1]]))
cat(sprintf("Test MAE: %.6f\n\n", evaluation[[2]]))

# Make predictions
cat("Making predictions on test set...\n")
y_pred <- model %>% predict(X_test, verbose = 0)

# Denormalize predictions and actual values
y_test_orig <- denormalize_data(y_test, normalized$min, normalized$max)
y_pred_orig <- denormalize_data(y_pred, normalized$min, normalized$max)

# Calculate metrics on original scale
metrics <- calculate_metrics(y_test_orig, y_pred_orig)

cat("\n========== Performance Metrics (Original Scale) ==========\n\n")
cat(sprintf("Mean Squared Error (MSE): %.6f\n", metrics$MSE))
cat(sprintf("Root Mean Squared Error (RMSE): %.6f\n", metrics$RMSE))
cat(sprintf("Mean Absolute Error (MAE): %.6f\n", metrics$MAE))
cat(sprintf("R-squared: %.6f\n\n", metrics$R_squared))

# Display sample predictions
cat("========== Sample Predictions ==========\n\n")
cat(sprintf("%-15s %-15s %-15s\n", "Actual", "Predicted", "Error"))
cat(strrep("-", 50), "\n")

n_samples <- min(10, length(y_test_orig))
for (i in 1:n_samples) {
  error <- abs(y_test_orig[i] - y_pred_orig[i])
  cat(sprintf("%-15.6f %-15.6f %-15.6f\n",
              y_test_orig[i], y_pred_orig[i], error))
}

# ========== Visualization ==========

if (requireNamespace("ggplot2", quietly = TRUE)) {
  library(ggplot2)

  cat("\n========== Creating Prediction Plot ==========\n\n")

  # Create data frame for plotting
  plot_data <- data.frame(
    Index = seq_along(y_test_orig),
    Actual = as.vector(y_test_orig),
    Predicted = as.vector(y_pred_orig)
  )

  # Reshape to long format for ggplot: prefer tidyr::pivot_longer if
  # available, fall back to reshape2::melt
  if (requireNamespace("tidyr", quietly = TRUE)) {
    plot_data_long <- tidyr::pivot_longer(plot_data,
                                          cols = -Index,
                                          names_to = "variable",
                                          values_to = "value")
  } else if (requireNamespace("reshape2", quietly = TRUE)) {
    plot_data_long <- reshape2::melt(plot_data, id.vars = "Index")
    # Ensure column names match the pivot_longer output
    names(plot_data_long) <- c("Index", "variable", "value")
  } else {
    stop("Please install 'tidyr' or 'reshape2' to create the plot (install.packages('tidyr')).")
  }

  # Create plot (linewidth replaces the deprecated size argument in ggplot2 >= 3.4)
  p <- ggplot(plot_data_long, aes(x = Index, y = value, color = variable)) +
    geom_line(linewidth = 1) +
    geom_point(alpha = 0.5) +
    scale_color_manual(values = c("Actual" = "blue", "Predicted" = "red")) +
    labs(
      title = "LSTM Time Series Predictions",
      subtitle = sprintf("Test Set: %d samples (RMSE: %.4f)",
                         length(y_test_orig), metrics$RMSE),
      x = "Sample Index",
      y = "Value",
      color = "Series"
    ) +
    theme_minimal() +
    theme(
      plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
      plot.subtitle = element_text(hjust = 0.5, size = 10),
      legend.position = "bottom"
    )

  print(p)

  cat("Plot created successfully!\n\n")
}

# ========== Additional Example: Multi-step Prediction ==========

cat("========== Multi-Step Ahead Prediction ==========\n\n")

#' Make multi-step predictions by feeding each prediction back in
#' @param model Trained LSTM model
#' @param initial_seq Initial sequence (array of shape c(1, seq_length, 1))
#' @param n_steps Number of steps to predict ahead
#' @param min_val Min value for denormalization
#' @param max_val Max value for denormalization
#' @return Vector of predictions (original scale)
predict_multi_step <- function(model, initial_seq, n_steps, min_val, max_val) {
  predictions <- numeric(n_steps)
  current_seq <- initial_seq
  # Derive the window length from the input itself rather than relying
  # on a global seq_length variable
  win_len <- dim(current_seq)[2]

  for (i in 1:n_steps) {
    # Predict the next (normalized) value
    pred <- model %>% predict(current_seq, verbose = 0)
    predictions[i] <- denormalize_data(pred, min_val, max_val)

    # Slide the window: drop the oldest value, append the new prediction
    current_seq <- array(
      c(current_seq[1, 2:win_len, 1], pred),
      dim = c(1, win_len, 1)
    )
  }

  return(predictions)
}

# Use the first test sequence as the seed for multi-step prediction
initial_sequence <- X_test[1, , , drop = FALSE]
n_future_steps <- 20

cat(sprintf("Predicting %d steps ahead...\n", n_future_steps))
future_predictions <- predict_multi_step(
  model,
  initial_sequence,
  n_future_steps,
  normalized$min,
  normalized$max
)

cat("\nMulti-step predictions:\n")
for (i in 1:min(10, n_future_steps)) {
  cat(sprintf("Step %2d: %.6f\n", i, future_predictions[i]))
}

# ========== Tips and Best Practices ==========

cat("\n========== LSTM Best Practices ==========\n\n")
cat("1. Data Preprocessing:\n")
cat("   - Normalize/standardize input data\n")
cat("   - Handle missing values appropriately\n")
cat("   - Consider detrending for non-stationary series\n\n")

cat("2. Model Architecture:\n")
cat("   - Start with 1-2 LSTM layers\n")
cat("   - Use dropout for regularization (0.2-0.5)\n")
cat("   - Consider bidirectional LSTM for complex patterns\n\n")

cat("3. Training:\n")
cat("   - Use an appropriate batch size (16-128)\n")
cat("   - Monitor validation loss to prevent overfitting\n")
cat("   - Use early stopping and model checkpointing\n\n")

cat("4. Hyperparameter Tuning:\n")
cat("   - Sequence length: depends on temporal dependencies\n")
cat("   - LSTM units: 32-256 typically works well\n")
cat("   - Learning rate: 0.001-0.01 for the Adam optimizer\n\n")

cat("5. Evaluation:\n")
cat("   - Use walk-forward validation for time series\n")
cat("   - Check residuals for patterns\n")
cat("   - Compare with baseline models (ARIMA, simple average)\n\n")

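# Hedged sketch of the walk-forward (rolling-origin) validation from
# tip 5, wrapped in `if (FALSE)` so it does not run here: repeatedly
# train on an expanding window and score the block that follows it.
# Fold sizes and epoch count are illustrative; the sketch reuses
# build_lstm_model() defined above.
if (FALSE) {
  n_total <- dim(X)[1]
  fold_starts <- seq(floor(0.5 * n_total), n_total - 50, by = 50)
  fold_rmse <- numeric(length(fold_starts))

  for (k in seq_along(fold_starts)) {
    cut <- fold_starts[k]
    test_block <- (cut + 1):min(cut + 50, n_total)

    fold_model <- build_lstm_model(seq_length)
    fold_model %>% fit(X[1:cut, , , drop = FALSE],
                       y[1:cut, , drop = FALSE],
                       epochs = 20, batch_size = 16, verbose = 0)

    preds <- fold_model %>% predict(X[test_block, , , drop = FALSE],
                                    verbose = 0)
    fold_rmse[k] <- sqrt(mean((y[test_block, , drop = FALSE] - preds)^2))
  }
  cat("Walk-forward RMSE per fold:", round(fold_rmse, 4), "\n")
}
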
cat("========== Example Complete ==========\n")
