
Commit ec01786

Merge pull request #2900 from smty2018/lstm
Spam Email Classification using LSTM
2 parents 7ab207e + 7187817 commit ec01786

File tree

3 files changed: +101858 lines, -0 lines
Lines changed: 240 additions & 0 deletions
New file (Jupyter notebook, Python 3 / Colab):

Cell 1: imports

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout
    from tensorflow.keras.preprocessing.text import Tokenizer
    from tensorflow.keras.preprocessing.sequence import pad_sequences

Cell 2: load the dataset and encode the labels

    # Map the string labels to integers: ham -> 0, spam -> 1
    data = pd.read_csv('spam_ham_dataset.csv')
    X = data['text']
    y = data['label'].map({'ham': 0, 'spam': 1})

Cell 3: train/test split

    # Hold out 20% of the emails for testing, with a fixed seed for reproducibility
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

Cell 4: tokenize and pad

    # Fit the tokenizer on the training texts only, then convert both splits
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(X_tr)
    X_tr_seq = tokenizer.texts_to_sequences(X_tr)
    X_te_seq = tokenizer.texts_to_sequences(X_te)

    # Pad/truncate every sequence to a fixed length of 100 tokens
    max_seq_len = 100
    X_tr_pad = pad_sequences(X_tr_seq, maxlen=max_seq_len, padding='post')
    X_te_pad = pad_sequences(X_te_seq, maxlen=max_seq_len, padding='post')

Cell 5: build, compile, and train the model

    vocab_size = len(tokenizer.word_index) + 1  # +1 for the reserved padding index 0
    model = Sequential()
    model.add(Embedding(input_dim=vocab_size, output_dim=128, input_length=max_seq_len))
    model.add(LSTM(64, return_sequences=True))  # pass the full sequence to the next LSTM
    model.add(Dropout(0.5))
    model.add(LSTM(64))                         # collapse the sequence into one vector
    model.add(Dense(1, activation='sigmoid'))   # probability of spam
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.fit(X_tr_pad, y_tr, epochs=10, batch_size=32, validation_split=0.2)

Output:

    Epoch 1/10
    104/104 [==============================] - 35s 297ms/step - loss: 0.5534 - accuracy: 0.7482 - val_loss: 0.5267 - val_accuracy: 0.7705
    Epoch 2/10
    104/104 [==============================] - 28s 272ms/step - loss: 0.5402 - accuracy: 0.7582 - val_loss: 0.5110 - val_accuracy: 0.7742
    Epoch 3/10
    104/104 [==============================] - 29s 277ms/step - loss: 0.4195 - accuracy: 0.8210 - val_loss: 0.2600 - val_accuracy: 0.9118
    Epoch 4/10
    104/104 [==============================] - 28s 271ms/step - loss: 0.2672 - accuracy: 0.8951 - val_loss: 0.1177 - val_accuracy: 0.9614
    Epoch 5/10
    104/104 [==============================] - 28s 268ms/step - loss: 0.0643 - accuracy: 0.9797 - val_loss: 0.1393 - val_accuracy: 0.9650
    Epoch 6/10
    104/104 [==============================] - 28s 268ms/step - loss: 0.0444 - accuracy: 0.9879 - val_loss: 0.1399 - val_accuracy: 0.9710
    Epoch 7/10
    104/104 [==============================] - 29s 280ms/step - loss: 0.0451 - accuracy: 0.9912 - val_loss: 0.1501 - val_accuracy: 0.9674
    Epoch 8/10
    104/104 [==============================] - 28s 268ms/step - loss: 0.0311 - accuracy: 0.9946 - val_loss: 0.1582 - val_accuracy: 0.9686
    Epoch 9/10
    104/104 [==============================] - 29s 279ms/step - loss: 0.0275 - accuracy: 0.9952 - val_loss: 0.1492 - val_accuracy: 0.9710
    Epoch 10/10
    104/104 [==============================] - 29s 283ms/step - loss: 0.0249 - accuracy: 0.9955 - val_loss: 0.1553 - val_accuracy: 0.9698

Cell 6: model summary

    model.summary()

Output:

    Model: "sequential_2"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #
    =================================================================
     embedding_2 (Embedding)     (None, 100, 128)          6625664
     lstm_4 (LSTM)               (None, 100, 64)           49408
     dropout_2 (Dropout)         (None, 100, 64)           0
     lstm_5 (LSTM)               (None, 64)                33024
     dense_2 (Dense)             (None, 1)                 65
    =================================================================
    Total params: 6,708,161
    Trainable params: 6,708,161
    Non-trainable params: 0
    _________________________________________________________________

Cell 7: evaluate on the held-out test set

    loss, acc = model.evaluate(X_te_pad, y_te)
    print(f"Test Loss: {loss:.4f}")
    print(f"Test Accuracy: {acc:.4f}")

Output:

    33/33 [==============================] - 1s 42ms/step - loss: 0.1547 - accuracy: 0.9700
    Test Loss: 0.1547
    Test Accuracy: 0.9700

Cell 8: classify a user-supplied email

    email_text = input("Enter an email text: ")

    sequence = tokenizer.texts_to_sequences([email_text])
    padded_sequence = pad_sequences(sequence, maxlen=max_seq_len, padding='post')
    prediction = model.predict(padded_sequence)

    # predict() returns a (1, 1) array of spam probabilities
    if prediction[0][0] > 0.5:
        print("Prediction: Spam")
    else:
        print("Prediction: Ham")

Output:

    Enter an email text: you won 1 mill
    1/1 [==============================] - 1s 855ms/step
    Prediction: Spam
Lines changed: 24 additions & 0 deletions
Description:
The primary goal of this project is to develop a machine learning model that automatically classifies emails as either spam or legitimate (ham) based on their content. An LSTM (Long Short-Term Memory) model, a recurrent neural network (RNN) architecture, is employed to capture sequential patterns in the text, which makes it well suited to natural language processing tasks like this one.

Dataset

The project uses a labeled dataset of emails containing both spam and ham samples. The text is tokenized and padded before being fed to the LSTM model. Replace `'spam_ham_dataset.csv'` in the code with your actual dataset path; a minimal loading sketch is shown below.
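
The sketch assumes, as in the notebook, that the CSV provides a `text` column and a `label` column holding the string values 'ham' and 'spam':

    import pandas as pd
    from sklearn.model_selection import train_test_split

    data = pd.read_csv('spam_ham_dataset.csv')    # replace with your dataset path
    X = data['text']
    y = data['label'].map({'ham': 0, 'spam': 1})  # ham -> 0, spam -> 1

    # 80/20 train/test split with a fixed seed for reproducibility
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)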
Model Architecture

The LSTM model consists of the following components (a condensed sketch follows the list):

- `Embedding` layer: converts the integer-encoded vocabulary into dense vectors of a fixed size.
- First `LSTM` layer: captures sequential patterns, returning the full sequence rather than a single output.
- `Dropout` layer: helps prevent overfitting by randomly deactivating a fraction of input units during training.
- Second `LSTM` layer: aggregates the output of the previous LSTM layer into a single vector.
- `Dense` layer: a single output unit with sigmoid activation for binary classification.

The model is trained using binary cross-entropy loss and the Adam optimizer.
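
The sketch below condenses the notebook's architecture; it assumes `tokenizer` is the Tokenizer already fitted on the training texts:

    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout

    max_seq_len = 100
    vocab_size = len(tokenizer.word_index) + 1  # +1 for the reserved padding index 0

    model = Sequential([
        Embedding(input_dim=vocab_size, output_dim=128, input_length=max_seq_len),
        LSTM(64, return_sequences=True),  # pass the full sequence to the next LSTM
        Dropout(0.5),
        LSTM(64),                         # collapse the sequence into one vector
        Dense(1, activation='sigmoid'),   # outputs the probability of spam
    ])
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])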

Install required packages (scikit-learn is needed for the train/test split):

    pip install numpy pandas scikit-learn tensorflow
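
After training, a single email can be classified as follows (a minimal sketch based on the notebook's final cell; `model` and `tokenizer` are the trained model and the fitted tokenizer):

    from tensorflow.keras.preprocessing.sequence import pad_sequences

    email_text = "you won 1 mill"  # example input from the notebook
    sequence = tokenizer.texts_to_sequences([email_text])
    padded = pad_sequences(sequence, maxlen=100, padding='post')  # same length as in training
    prob = model.predict(padded)[0][0]  # sigmoid output: probability of spam
    print("Prediction: Spam" if prob > 0.5 else "Prediction: Ham")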