Skip to content

Commit 9fa9f46

Browse files
authored
Create prediction_model.py
1 parent 1d91ee9 commit 9fa9f46

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

algorithms/prediction_model.py

Lines changed: 95 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,95 @@
1+
import numpy as np
2+
import pandas as pd
3+
from sklearn.ensemble import GradientBoostingRegressor
4+
from sklearn.model_selection import train_test_split, GridSearchCV
5+
import joblib
6+
import logging
7+
import yaml
8+
9+
class DemandPredictionModel:
    """Demand regressor backed by a scikit-learn gradient boosting model.

    On construction the instance reads its settings from a YAML file and
    then tries to load a previously saved model from ``self.model_path``.
    If no saved model exists, ``self.model`` stays ``None`` until
    :meth:`train_model` is called.
    """

    # Feature columns, in the order used for both training and prediction.
    FEATURE_COLUMNS = ['market_price', 'current_supply', 'other_factors']
    # Name of the column holding the regression target.
    TARGET_COLUMN = 'demand'

    def __init__(self, config_path='config.yaml'):
        """Read the configuration and attempt to load a saved model.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.model = None
        self.load_config(config_path)
        self.load_model(self.model_path)

    def load_config(self, config_path):
        """Load configuration from a YAML file.

        Sets ``self.model_path`` (key ``demand_model_path``) and
        ``self.test_size`` (key ``test_size``). Both keys are optional, so
        a missing or empty configuration file is not an error: a warning
        is logged for a missing file and the built-in defaults are used.

        Args:
            config_path: Path to the YAML configuration file.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as file:
                # safe_load returns None for an empty document; treat that
                # as "no settings" instead of failing on ``None.get``.
                config = yaml.safe_load(file) or {}
        except FileNotFoundError:
            logging.warning(f'Config file not found at {config_path}. Using default settings.')
            config = {}

        self.model_path = config.get('demand_model_path', 'demand_prediction_model.pkl')
        self.test_size = config.get('test_size', 0.2)

    def load_model(self, model_path):
        """Load a pre-trained machine learning model.

        A missing file is tolerated: a warning is logged and
        ``self.model`` is left unchanged.

        Args:
            model_path: Path of the file written by ``joblib.dump``.
        """
        try:
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code. Only point model_path at trusted files.
            self.model = joblib.load(model_path)
            logging.info(f'Model loaded from {model_path}')
        except FileNotFoundError:
            logging.warning(f'Model file not found at {model_path}. Please train the model first.')

    def train_model(self, historical_data):
        """Train a Gradient Boosting model on historical demand data.

        The data is split into train/test parts using ``self.test_size``,
        a 5-fold grid search picks the hyperparameters on the training
        part, the best estimator is stored in ``self.model`` and saved to
        ``self.model_path``, and it is finally scored on the held-out part.

        Args:
            historical_data: DataFrame containing the ``FEATURE_COLUMNS``
                and the ``TARGET_COLUMN``. The training part must hold at
                least 5 rows because the grid search uses ``cv=5``.

        Returns:
            The R^2 score of the selected model on the held-out test part.
        """
        # Prepare the data
        X = historical_data[list(self.FEATURE_COLUMNS)]
        y = historical_data[self.TARGET_COLUMN]

        # Split the data into training and testing sets
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.test_size, random_state=42
        )

        # Hyperparameter tuning using Grid Search
        param_grid = {
            'n_estimators': [100, 200],
            'learning_rate': [0.01, 0.1, 0.2],
            'max_depth': [3, 5, 7],
        }
        # Seed the estimator as well as the split so repeated runs on the
        # same data select the same model.
        grid_search = GridSearchCV(
            GradientBoostingRegressor(random_state=42), param_grid, cv=5
        )
        grid_search.fit(X_train, y_train)

        # Best model from grid search
        self.model = grid_search.best_estimator_
        logging.info(f'Best parameters: {grid_search.best_params_}')

        # Save the model
        joblib.dump(self.model, self.model_path)
        logging.info(f'Model trained and saved to {self.model_path}')

        # Evaluate the model
        score = self.model.score(X_test, y_test)
        logging.info(f'Model trained with R^2 score: {score}')
        return score

    def predict_demand(self, market_price, current_supply, other_factors):
        """Predict demand based on market conditions.

        Args:
            market_price: Value for the ``market_price`` feature.
            current_supply: Value for the ``current_supply`` feature.
            other_factors: Value for the ``other_factors`` feature.

        Returns:
            The first (and only) element of the model's prediction.

        Raises:
            RuntimeError: If no model has been trained or loaded.
        """
        if self.model is None:
            raise RuntimeError("Model is not trained or loaded.")

        # Prepare input for prediction
        row = [[market_price, current_supply, other_factors]]
        if hasattr(self.model, 'feature_names_in_'):
            # The model was fitted on a named DataFrame; pass the same
            # column names so scikit-learn does not warn about a
            # feature-name mismatch.
            input_data = pd.DataFrame(row, columns=list(self.FEATURE_COLUMNS))
        else:
            input_data = np.array(row)

        predicted_demand = self.model.predict(input_data)
        logging.info(f'Predicted demand: {predicted_demand[0]}')
        return predicted_demand[0]
70+
71+
# Example usage
72+
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    # Load historical data (this should be replaced with actual data).
    # The sample needs enough rows for the training pipeline: with the
    # default 20% hold-out, 20 rows leave 16 for the 5-fold grid search
    # and 4 for scoring. The earlier 5-row sample left only 4 training
    # rows, which makes 5-fold cross-validation fail with a ValueError.
    historical_data = pd.DataFrame({
        'market_price': [
            100, 105, 110, 95, 90, 102, 108, 97, 93, 115,
            99, 104, 112, 96, 91, 101, 107, 98, 94, 109,
        ],
        'current_supply': [
            1000, 1100, 1200, 900, 800, 1050, 1150, 950, 850, 1300,
            980, 1080, 1250, 920, 820, 1020, 1120, 960, 880, 1180,
        ],
        'other_factors': [
            1, 2, 1, 2, 1, 2, 1, 2, 1, 2,
            1, 2, 1, 2, 1, 2, 1, 2, 1, 2,
        ],
        'demand': [
            950, 1150, 1250, 850, 750, 1100, 1120, 980, 800, 1350,
            940, 1130, 1230, 930, 780, 1060, 1090, 990, 830, 1230,
        ],
    })

    # Initialize the demand prediction model
    demand_model = DemandPredictionModel()

    # Train the model
    demand_model.train_model(historical_data)

    # Predict demand based on current market conditions
    market_price = 105
    current_supply = 1000
    other_factors = 1
    predicted_demand = demand_model.predict_demand(market_price, current_supply, other_factors)
    print(f'Predicted Demand: {predicted_demand}')

0 commit comments

Comments
 (0)