|
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | +from sklearn.ensemble import GradientBoostingRegressor |
| 4 | +from sklearn.model_selection import train_test_split, GridSearchCV |
| 5 | +import joblib |
| 6 | +import logging |
| 7 | +import yaml |
| 8 | + |
| 9 | +class DemandPredictionModel: |
| 10 | + def __init__(self, config_path='config.yaml'): |
| 11 | + self.model = None |
| 12 | + self.load_config(config_path) |
| 13 | + self.load_model(self.model_path) |
| 14 | + |
| 15 | + def load_config(self, config_path): |
| 16 | + """Load configuration from a YAML file.""" |
| 17 | + with open(config_path, 'r') as file: |
| 18 | + config = yaml.safe_load(file) |
| 19 | + self.model_path = config.get('demand_model_path', 'demand_prediction_model.pkl') |
| 20 | + self.test_size = config.get('test_size', 0.2) |
| 21 | + |
| 22 | + def load_model(self, model_path): |
| 23 | + """Load a pre-trained machine learning model.""" |
| 24 | + try: |
| 25 | + self.model = joblib.load(model_path) |
| 26 | + logging.info(f'Model loaded from {model_path}') |
| 27 | + except FileNotFoundError: |
| 28 | + logging.warning(f'Model file not found at {model_path}. Please train the model first.') |
| 29 | + |
| 30 | + def train_model(self, historical_data): |
| 31 | + """Train a Gradient Boosting model on historical demand data.""" |
| 32 | + # Prepare the data |
| 33 | + X = historical_data[['market_price', 'current_supply', 'other_factors']] |
| 34 | + y = historical_data['demand'] |
| 35 | + |
| 36 | + # Split the data into training and testing sets |
| 37 | + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=self.test_size, random_state=42) |
| 38 | + |
| 39 | + # Hyperparameter tuning using Grid Search |
| 40 | + param_grid = { |
| 41 | + 'n_estimators': [100, 200], |
| 42 | + 'learning_rate': [0.01, 0.1, 0.2], |
| 43 | + 'max_depth': [3, 5, 7] |
| 44 | + } |
| 45 | + grid_search = GridSearchCV(GradientBoostingRegressor(), param_grid, cv=5) |
| 46 | + grid_search.fit(X_train, y_train) |
| 47 | + |
| 48 | + # Best model from grid search |
| 49 | + self.model = grid_search.best_estimator_ |
| 50 | + logging.info(f'Best parameters: {grid_search.best_params_}') |
| 51 | + |
| 52 | + # Save the model |
| 53 | + joblib.dump(self.model, self.model_path) |
| 54 | + logging.info(f'Model trained and saved to {self.model_path}') |
| 55 | + |
| 56 | + # Evaluate the model |
| 57 | + score = self.model.score(X_test, y_test) |
| 58 | + logging.info(f'Model trained with R^2 score: {score}') |
| 59 | + |
| 60 | + def predict_demand(self, market_price, current_supply, other_factors): |
| 61 | + """Predict demand based on market conditions.""" |
| 62 | + if self.model is None: |
| 63 | + raise Exception("Model is not trained or loaded.") |
| 64 | + |
| 65 | + # Prepare input for prediction |
| 66 | + input_data = np.array([[market_price, current_supply, other_factors]]) |
| 67 | + predicted_demand = self.model.predict(input_data) |
| 68 | + logging.info(f'Predicted demand: {predicted_demand[0]}') |
| 69 | + return predicted_demand[0] |
| 70 | + |
| 71 | +# Example usage |
| 72 | +if __name__ == "__main__": |
| 73 | + # Configure logging |
| 74 | + logging.basicConfig(level=logging.INFO) |
| 75 | + |
| 76 | + # Load historical data (this should be replaced with actual data) |
| 77 | + historical_data = pd.DataFrame({ |
| 78 | + 'market_price': [100, 105, 110, 95, 90], |
| 79 | + 'current_supply': [1000, 1100, 1200, 900, 800], |
| 80 | + 'other_factors': [1, 2, 1, 2, 1], |
| 81 | + 'demand': [950, 1150, 1250, 850, 750] |
| 82 | + }) |
| 83 | + |
| 84 | + # Initialize the demand prediction model |
| 85 | + demand_model = DemandPredictionModel() |
| 86 | + |
| 87 | + # Train the model |
| 88 | + demand_model.train_model(historical_data) |
| 89 | + |
| 90 | + # Predict demand based on current market conditions |
| 91 | + market_price = 105 |
| 92 | + current_supply = 1000 |
| 93 | + other_factors = 1 |
| 94 | + predicted_demand = demand_model.predict_demand(market_price, current_supply, other_factors) |
| 95 | + print(f'Predicted Demand: {predicted_demand}') |
0 commit comments