+{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Ch4. 데이터 모델링.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"}},"cells":[{"cell_type":"code","metadata":{"id":"HszE6YFCHI7e","executionInfo":{"status":"ok","timestamp":1606058711003,"user_tz":-540,"elapsed":3912,"user":{"displayName":"안성진","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiCjgkN_MvtrSUHRuFvstrWm6fhi5cf7CKd2UHYAw=s64","userId":"00266029492778998652"}},"outputId":"1b05c673-7fae-41ec-978c-e1ba06d385d4","colab":{"base_uri":"https://localhost:8080/"}},"source":["import torch\n","import os\n","import numpy as np\n","import pandas as pd\n","from tqdm import tqdm\n","import seaborn as sns\n","from pylab import rcParams\n","import matplotlib.pyplot as plt\n","from matplotlib import rc\n","from sklearn.preprocessing import MinMaxScaler\n","from pandas.plotting import register_matplotlib_converters\n","from torch import nn, optim\n","\n","%matplotlib inline\n","%config InlineBackend.figure_format='retina'\n","\n","sns.set(style='whitegrid', palette='muted', font_scale=1.2)\n","HAPPY_COLORS_PALETTE = [\"#01BEFE\", \"#FFDD00\", \"#FF7D00\", \"#FF006D\", \"#93D30C\", \"#8F00FF\"]\n","sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))\n","rcParams['figure.figsize'] = 14, 10\n","register_matplotlib_converters()\n","RANDOM_SEED = 42\n","np.random.seed(RANDOM_SEED)\n","torch.manual_seed(RANDOM_SEED)"],"execution_count":1,"outputs":[{"output_type":"execute_result","data":{"text/plain":["<torch._C.Generator at 0x7f805232db58>"]},"metadata":{"tags":[]},"execution_count":1}]},{"cell_type":"code","metadata":{"id":"7vOV2GR_g0zR","executionInfo":{"status":"ok","timestamp":1606058715469,"user_tz":-540,"elapsed":1990,"user":{"displayName":"안성진","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiCjgkN_MvtrSUHRuFvstrWm6fhi5cf7CKd2UHYAw=s64","userId":"00266029492778998652"}},"outputId":"b22d18ea-ce8a-4cfe-d1cb-54ff939d48b1","colab":{"base_uri":"https://localhost:8080/"}},"source":["!gdown --id 1QRKd4uiL8GKaH4DuuMsYRCS2gV0tAiwa"],"execution_count":2,"outputs":[{"output_type":"stream","text":["Downloading...\n","From: https://drive.google.com/uc?id=1QRKd4uiL8GKaH4DuuMsYRCS2gV0tAiwa\n","To: /content/korea_South_19-covid-Confirmed.csv\n","\r 0% 0.00/4.17k [00:00<?, ?B/s]\r100% 4.17k/4.17k [00:00<00:00, 7.02MB/s]\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"xLQtYiv_g13j"},"source":["df = pd.read_csv('korea_South_19-covid-Confirmed.csv')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"3PiQMi4lg3xD"},"source":["df = df.iloc[:, 4:]\n","print(df)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"C42DJBMig3vI"},"source":["df.isnull().sum().sum()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xPtl969rg3sS"},"source":["daily_cases = df.sum(axis=0)\n","daily_cases.index = pd.to_datetime(daily_cases.index)\n","daily_cases.head()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"bcmIYEqHg3p5"},"source":["plt.plot(daily_cases)\n","plt.title(\"Cumulative daily cases\");"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"izDEpb-Ag3nG"},"source":["daily_cases = daily_cases.diff().fillna(daily_cases[0]).astype(np.int64)\n","daily_cases.head()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ddrQR4nNg3kf"},"source":["plt.plot(daily_cases)\n","plt.title(\"Daily cases\");"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"D-b73djRg3jC"},"source":["daily_cases.shape"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"_t-cCffZg3fP"},"source":["test_data_size = 14\n","train_data = daily_cases[:-test_data_size]\n","test_data = daily_cases[-test_data_size:]\n","train_data.shape"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"1a6gsm3Xg3d9"},"source":["scaler = MinMaxScaler()\n","scaler = scaler.fit(np.expand_dims(train_data, axis=1))\n","train_data = scaler.transform(np.expand_dims(train_data, axis=1))\n","test_data = scaler.transform(np.expand_dims(test_data, axis=1))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"nU0ucpQZg3Zv"},"source":["def create_sequences(data, seq_length):\n"," xs = []\n"," ys = []\n"," for i in range(len(data)-seq_length-1):\n"," x = data[i:(i+seq_length)]\n"," y = data[i+seq_length]\n"," xs.append(x)\n"," ys.append(y)\n"," return np.array(xs), np.array(ys)\n","seq_length = 5\n","X_train, y_train = create_sequences(train_data, seq_length)\n","X_test, y_test = create_sequences(test_data, seq_length)\n","X_train = torch.from_numpy(X_train).float()\n","y_train = torch.from_numpy(y_train).float()\n","X_test = torch.from_numpy(X_test).float()\n","y_test = torch.from_numpy(y_test).float()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"G1s3rJ9og3XH"},"source":["# X_train.shape\n","print(y_test)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"y6xnam7Gg3UV"},"source":["X_train[:2]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ghco-_Fyg3Rs"},"source":["class CoronaVirusPredictor(nn.Module):\n"," def __init__(self, n_features, n_hidden, seq_len, n_layers=2):\n"," super(CoronaVirusPredictor, self).__init__()\n"," self.n_hidden = n_hidden\n"," self.seq_len = seq_len\n"," self.n_layers = n_layers\n"," self.lstm = nn.LSTM(\n"," input_size=n_features,\n"," hidden_size=n_hidden,\n"," num_layers=n_layers,\n"," dropout=0.5\n"," )\n"," self.linear = nn.Linear(in_features=n_hidden, out_features=1)\n"," def reset_hidden_state(self):\n"," self.hidden = (\n"," torch.zeros(self.n_layers, self.seq_len, self.n_hidden),\n"," torch.zeros(self.n_layers, self.seq_len, self.n_hidden)\n"," )\n"," def forward(self, sequences):\n"," lstm_out, self.hidden = self.lstm(\n"," sequences.view(len(sequences), self.seq_len, -1),\n"," self.hidden\n"," )\n"," last_time_step = \\\n"," lstm_out.view(self.seq_len, len(sequences), self.n_hidden)[-1]\n"," y_pred = self.linear(last_time_step)\n"," return y_pred"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wg---IHWg3H_"},"source":["def train_model(\n"," model,\n"," train_data,\n"," train_labels,\n"," test_data=None,\n"," test_labels=None\n","):\n"," loss_fn = torch.nn.MSELoss(reduction='sum')\n"," optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)\n"," num_epochs = 60\n"," train_hist = np.zeros(num_epochs)\n"," test_hist = np.zeros(num_epochs)\n"," for t in range(num_epochs):\n"," model.reset_hidden_state()\n"," y_pred = model(X_train)\n"," loss = loss_fn(y_pred.float(), y_train)\n"," if test_data is not None:\n"," with torch.no_grad():\n"," y_test_pred = model(X_test)\n"," test_loss = loss_fn(y_test_pred.float(), y_test)\n"," test_hist[t] = test_loss.item()\n"," if t % 10 == 0:\n"," print(f'Epoch {t} train loss: {loss.item()} test loss: {test_loss.item()}')\n"," elif t % 10 == 0:\n"," print(f'Epoch {t} train loss: {loss.item()}')\n"," train_hist[t] = loss.item()\n"," optimiser.zero_grad()\n"," loss.backward()\n"," optimiser.step()\n"," return model.eval(), train_hist, test_hist"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"qKRjXR6ThBhQ"},"source":["model = CoronaVirusPredictor(\n"," n_features=1,\n"," n_hidden=512,\n"," seq_len=seq_length,\n"," n_layers=2\n",")\n","model, train_hist, test_hist = train_model(\n"," model,\n"," X_train,\n"," y_train,\n"," X_test,\n"," y_test\n",")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"LMxD3gW5hBeo"},"source":["plt.plot(train_hist, label=\"Training loss\")\n","plt.plot(test_hist, label=\"Test loss\")\n","plt.ylim((0, 5))\n","plt.legend();"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xI4sXM0dhBd-"},"source":["with torch.no_grad():\n"," test_seq = X_test[:1]\n"," preds = []\n"," for _ in range(len(X_test)):\n"," y_test_pred = model(test_seq)\n"," pred = torch.flatten(y_test_pred).item()\n"," preds.append(pred)\n"," new_seq = test_seq.numpy().flatten()\n"," new_seq = np.append(new_seq, [pred])\n"," new_seq = new_seq[1:]\n"," test_seq = torch.as_tensor(new_seq).view(1, seq_length, 1).float()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lcWoRGFshBZl"},"source":["true_cases = scaler.inverse_transform(\n"," np.expand_dims(y_test.flatten().numpy(), axis=0)\n",").flatten()\n","predicted_cases = scaler.inverse_transform(\n"," np.expand_dims(preds, axis=0)\n",").flatten()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"0DNviPbfhBWt"},"source":["plt.plot(\n"," daily_cases.index[:len(train_data)],\n"," scaler.inverse_transform(train_data).flatten(),\n"," label='Historical Daily Cases'\n",")\n","plt.plot(\n"," daily_cases.index[len(train_data):len(train_data) + len(true_cases)],\n"," true_cases,\n"," label='Real Daily Cases'\n",")\n","plt.plot(\n"," daily_cases.index[len(train_data):len(train_data) + len(true_cases)],\n"," predicted_cases,\n"," label='Predicted Daily Cases'\n",")\n","plt.legend();"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"TFMU1QTzhBT5"},"source":["scaler = MinMaxScaler()\n","scaler = scaler.fit(np.expand_dims(daily_cases, axis=1))\n","all_data = scaler.transform(np.expand_dims(daily_cases, axis=1))\n","all_data.shape"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"u6MlVTUDhG8h"},"source":["X_all, y_all = create_sequences(all_data, seq_length)\n","X_all = torch.from_numpy(X_all).float()\n","y_all = torch.from_numpy(y_all).float()\n","model = CoronaVirusPredictor(\n"," n_features=1,\n"," n_hidden=512,\n"," seq_len=seq_length,\n"," n_layers=2\n",")\n","model, train_hist, _ = train_model(model, X_all, y_all)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"O9mgEK8WhG6C"},"source":["DAYS_TO_PREDICT = 12\n","with torch.no_grad():\n"," test_seq = X_all[:1]\n"," preds = []\n"," for _ in range(DAYS_TO_PREDICT):\n"," y_test_pred = model(test_seq)\n"," pred = torch.flatten(y_test_pred).item()\n"," preds.append(pred)\n"," new_seq = test_seq.numpy().flatten()\n"," new_seq = np.append(new_seq, [pred])\n"," new_seq = new_seq[1:]\n"," test_seq = torch.as_tensor(new_seq).view(1, seq_length, 1).float()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"XowhhUnthG1u"},"source":["predicted_cases = scaler.inverse_transform(\n"," np.expand_dims(preds, axis=0)\n",").flatten()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"q9MDLtx0hBQ9"},"source":["daily_cases.index[-1]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"yo09N_sZhKX1"},"source":["predicted_index = pd.date_range(\n"," start=daily_cases.index[-1],\n"," periods=DAYS_TO_PREDICT + 1,\n"," closed='right'\n",")\n","predicted_cases = pd.Series(\n"," data=predicted_cases,\n"," index=predicted_index\n",")\n","plt.plot(predicted_cases, label='Predicted Daily Cases')\n","plt.legend();"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NIOiVWRVhMZE"},"source":["plt.plot(daily_cases, label='Historical Daily Cases')\n","plt.plot(predicted_cases, label='Predicted Daily Cases')\n","plt.legend();"],"execution_count":null,"outputs":[]}]}
0 commit comments