Skip to content

Commit e37f183

Browse files
authored
Merge pull request #315 from itkovian/auth-ssl-mail
feat: allow authenticated + ssl mail
2 parents 573c39d + 75a3704 commit e37f183

File tree

3 files changed

+183
-60
lines changed

3 files changed

+183
-60
lines changed

lib/vsc/utils/mail.py

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,11 @@
3434
import logging
3535
import re
3636
import smtplib
37-
37+
import ssl
3838
from email.mime.multipart import MIMEMultipart
3939
from email.mime.text import MIMEText
4040
from email.mime.image import MIMEImage
4141

42-
4342
class VscMailError(Exception):
4443
"""Raised if the sending of an email fails for some reason."""
4544

@@ -68,14 +67,43 @@ def __init__(self, mail_host=None, mail_to=None, mail_from=None, mail_subject=No
6867
class VscMail(object):
6968
"""Class providing functionality to send out mail."""
7069

71-
def __init__(self, mail_host=None):
70+
def __init__(
71+
self,
72+
mail_host='',
73+
mail_port=0,
74+
smtp_auth_user=None,
75+
smtp_auth_password=None,
76+
smtp_use_starttls=False):
77+
7278
self.mail_host = mail_host
79+
self.mail_port = mail_port
80+
self.smtp_auth_user = smtp_auth_user
81+
self.smtp_auth_password = smtp_auth_password
82+
self.smtp_use_starttls = smtp_use_starttls
83+
84+
def _connect(self):
85+
"""
86+
Connect to the mail host on the given port.
87+
88+
If provided, use authentication and TLS.
89+
"""
90+
logging.debug("Using mail host %s, mail port %d", self.mail_host, self.mail_port)
91+
s = smtplib.SMTP(host=self.mail_host, port=self.mail_port)
7392

74-
def _send(self,
75-
mail_from,
76-
mail_to,
77-
mail_subject,
78-
msg):
93+
if self.smtp_use_starttls:
94+
context = ssl.create_default_context()
95+
s.starttls(context=context)
96+
logging.debug("Started TLS connection")
97+
98+
if self.smtp_auth_user and self.smtp_auth_password:
99+
s.login(user=self.smtp_auth_user, password=self.smtp_auth_password)
100+
logging.debug("Authenticated")
101+
102+
s.connect()
103+
104+
return s
105+
106+
def _send(self, mail_from, mail_to, mail_subject, msg):
79107
"""Actually send the mail.
80108
81109
@type mail_from: string representing the sender.
@@ -85,13 +113,8 @@ def _send(self,
85113
"""
86114

87115
try:
88-
if self.mail_host:
89-
logging.debug("Using %s as the mail host", self.mail_host)
90-
s = smtplib.SMTP(self.mail_host)
91-
else:
92-
logging.debug("Using the default mail host")
93-
s = smtplib.SMTP()
94-
s.connect()
116+
s = self._connect()
117+
95118
try:
96119
s.sendmail(mail_from, mail_to, msg.as_string())
97120
except smtplib.SMTPHeloError as err:
@@ -105,27 +128,27 @@ def _send(self,
105128
raise
106129
except smtplib.SMTPDataError as err:
107130
raise
131+
108132
except smtplib.SMTPConnectError as err:
109133
logging.exception("Cannot connect to the SMTP host %s", self.mail_host)
110-
raise VscMailError(mail_host=self.mail_host,
111-
mail_to=mail_to,
112-
mail_from=mail_from,
113-
mail_subject=mail_subject,
114-
err=err)
134+
raise VscMailError(
135+
mail_host=self.mail_host,
136+
mail_to=mail_to,
137+
mail_from=mail_from,
138+
mail_subject=mail_subject,
139+
err=err)
115140
except Exception as err:
116141
logging.exception("Some unknown exception occurred in VscMail.sendTextMail. Raising a VscMailError.")
117-
raise VscMailError(mail_host=self.mail_host,
118-
mail_to=mail_to,
119-
mail_from=mail_from,
120-
mail_subject=mail_subject,
121-
err=err)
122-
123-
def sendTextMail(self,
124-
mail_to,
125-
mail_from,
126-
reply_to,
127-
mail_subject,
128-
message):
142+
raise VscMailError(
143+
mail_host=self.mail_host,
144+
mail_to=mail_to,
145+
mail_from=mail_from,
146+
mail_subject=mail_subject,
147+
err=err)
148+
else:
149+
s.quit()
150+
151+
def sendTextMail(self, mail_to, mail_from, reply_to, mail_subject, message):
129152
"""Send out the given message by mail to the given recipient(s).
130153
131154
@type mail_to: string or list of strings
@@ -174,15 +197,16 @@ def _replace_images_cid(self, html, images):
174197

175198
return html
176199

177-
def sendHTMLMail(self,
178-
mail_to,
179-
mail_from,
180-
reply_to,
181-
mail_subject,
182-
html_message,
183-
text_alternative,
184-
images=None,
185-
css=None):
200+
def sendHTMLMail(
201+
self,
202+
mail_to,
203+
mail_from,
204+
reply_to,
205+
mail_subject,
206+
html_message,
207+
text_alternative,
208+
images=None,
209+
css=None):
186210
"""
187211
Send an HTML email message, encoded in a MIME/multipart message.
188212
@@ -222,7 +246,7 @@ def sendHTMLMail(self,
222246

223247
# Create the body of the message (a plain-text and an HTML version).
224248
if images is not None:
225-
html_message = self.replace_images_cid(html_message, images)
249+
html_message = self._replace_images_cid(html_message, images)
226250

227251
# Record the MIME types of both parts - text/plain and text/html_message.
228252
msg_plain = MIMEText(text_alternative, 'plain')

test/exceptions.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ def test_loggedexception_defaultlogger(self):
6262
logToFile(tmplog, enable=False)
6363

6464
log_re = re.compile("^%s :: BOOM( \(at .*:[0-9]+ in raise_loggedexception\))?$" % getRootLoggerName(), re.M)
65-
logtxt = open(tmplog, 'r').read()
66-
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
65+
with open(tmplog, 'r') as f:
66+
logtxt = f.read()
67+
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
6768

6869
# test formatting of message
6970
self.assertErrorRegex(LoggedException, 'BOOMBAF', raise_loggedexception, 'BOOM%s', 'BAF')
@@ -92,8 +93,9 @@ def test_loggedexception_specifiedlogger(self):
9293

9394
rootlog = getRootLoggerName()
9495
log_re = re.compile("^%s.testlogger_one :: BOOM( \(at .*:[0-9]+ in raise_loggedexception\))?$" % rootlog, re.M)
95-
logtxt = open(tmplog, 'r').read()
96-
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
96+
with open(tmplog, 'r') as f:
97+
logtxt = f.read()
98+
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
9799

98100
os.remove(tmplog)
99101

@@ -114,8 +116,9 @@ def test_loggedexception_callerlogger(self):
114116

115117
rootlog = getRootLoggerName()
116118
log_re = re.compile("^%s(.testlogger_local)? :: BOOM( \(at .*:[0-9]+ in raise_loggedexception\))?$" % rootlog)
117-
logtxt = open(tmplog, 'r').read()
118-
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
119+
with open(tmplog, 'r') as f:
120+
logtxt = f.read()
121+
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
119122

120123
os.remove(tmplog)
121124

@@ -143,12 +146,12 @@ def raise_testexception(msg, *args, **kwargs):
143146
rootlogname = getRootLoggerName()
144147

145148
log_re = re.compile("^%s :: BOOM$" % rootlogname, re.M)
146-
logtxt = open(tmplog, 'r').read()
147-
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
149+
with open(tmplog, 'r') as f:
150+
logtxt = f.read()
151+
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
148152

149-
f = open(tmplog, 'w')
150-
f.write('')
151-
f.close()
153+
with open(tmplog, 'w') as f:
154+
f.write('')
152155

153156
# location is included if LOC_INFO_TOP_PKG_NAMES is defined
154157
TestException.LOC_INFO_TOP_PKG_NAMES = ['vsc']
@@ -157,12 +160,12 @@ def raise_testexception(msg, *args, **kwargs):
157160
logToFile(tmplog, enable=False)
158161

159162
log_re = re.compile(r"^%s :: BOOM \(at (?:.*?/)?vsc/install/testing.py:[0-9]+ in assertErrorRegex\)$" % rootlogname)
160-
logtxt = open(tmplog, 'r').read()
161-
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
163+
with open(tmplog, 'r') as f:
164+
logtxt = f.read()
165+
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
162166

163-
f = open(tmplog, 'w')
164-
f.write('')
165-
f.close()
167+
with open(tmplog, 'w') as f:
168+
f.write('')
166169

167170
# absolute path of location is included if there's no match in LOC_INFO_TOP_PKG_NAMES
168171
TestException.LOC_INFO_TOP_PKG_NAMES = ['foobar']
@@ -171,8 +174,9 @@ def raise_testexception(msg, *args, **kwargs):
171174
logToFile(tmplog, enable=False)
172175

173176
log_re = re.compile(r"^%s :: BOOM \(at (?:.*?/)?vsc/install/testing.py:[0-9]+ in assertErrorRegex\)$" % rootlogname)
174-
logtxt = open(tmplog, 'r').read()
175-
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
177+
with open(tmplog, 'r') as f:
178+
logtxt = f.read()
179+
self.assertTrue(log_re.match(logtxt), "%s matches %s" % (log_re.pattern, logtxt))
176180

177181
os.remove(tmplog)
178182

test/mail.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#
2+
# Copyright 2014-2021 Ghent University
3+
#
4+
# This file is part of vsc-base,
5+
# originally created by the HPC team of Ghent University (http://ugent.be/hpc/en),
6+
# with support of Ghent University (http://ugent.be/hpc),
7+
# the Flemish Supercomputer Centre (VSC) (https://www.vscentrum.be),
8+
# the Flemish Research Foundation (FWO) (http://www.fwo.be/en)
9+
# and the Department of Economy, Science and Innovation (EWI) (http://www.ewi-vlaanderen.be/en).
10+
#
11+
# https://github.com/hpcugent/vsc-base
12+
#
13+
# vsc-base is free software: you can redistribute it and/or modify
14+
# it under the terms of the GNU Library General Public License as
15+
# published by the Free Software Foundation, either version 2 of
16+
# the License, or (at your option) any later version.
17+
#
18+
# vsc-base is distributed in the hope that it will be useful,
19+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
20+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21+
# GNU Library General Public License for more details.
22+
#
23+
# You should have received a copy of the GNU Library General Public License
24+
# along with vsc-base. If not, see <http://www.gnu.org/licenses/>.
25+
#
26+
"""
27+
Unit tests for the mail wrapper.
28+
29+
@author: Andy Georges (Ghent University)
30+
"""
31+
import mock
32+
import os
33+
34+
from vsc.install.testing import TestCase
35+
36+
from email.mime.text import MIMEText
37+
from vsc.utils.mail import VscMail
38+
39+
class TestVscMail(TestCase):
40+
41+
@mock.patch('vsc.utils.mail.smtplib')
42+
@mock.patch('vsc.utils.mail.ssl')
43+
def test_send(self, mock_ssl, mock_smtplib):
44+
45+
msg = MIMEText("test")
46+
msg['Subject'] = "subject"
47+
msg['From'] = "[email protected]"
48+
msg['To'] = "[email protected]"
49+
msg['Reply-to'] = "[email protected]"
50+
51+
vm = VscMail()
52+
53+
self.assertEqual(vm.mail_host, '')
54+
self.assertEqual(vm.mail_port, 0)
55+
self.assertEqual(vm.smtp_auth_user, None)
56+
self.assertEqual(vm.smtp_auth_password, None)
57+
self.assertEqual(vm.smtp_use_starttls, False)
58+
59+
vm._send(mail_from="[email protected]", mail_to="[email protected]", mail_subject="s", msg=msg)
60+
61+
vm = VscMail(
62+
mail_host = "test.machine.com",
63+
mail_port=123,
64+
smtp_auth_user="me",
65+
smtp_auth_password="hunter2",
66+
)
67+
68+
self.assertEqual(vm.mail_host, "test.machine.com")
69+
self.assertEqual(vm.mail_port, 123)
70+
self.assertEqual(vm.smtp_auth_user, "me")
71+
self.assertEqual(vm.smtp_auth_password, "hunter2")
72+
self.assertEqual(vm.smtp_use_starttls, False)
73+
74+
vm._send(mail_from="[email protected]", mail_to="[email protected]", mail_subject="s", msg=msg)
75+
76+
mock_smtplib.SMTP.assert_called_with(host="test.machine.com", port=123)
77+
78+
vm = VscMail(
79+
mail_host = "test.machine.com",
80+
mail_port=124,
81+
smtp_auth_user="me",
82+
smtp_auth_password="hunter2",
83+
smtp_use_starttls=True
84+
)
85+
86+
self.assertEqual(vm.mail_host, "test.machine.com")
87+
self.assertEqual(vm.mail_port, 124)
88+
self.assertEqual(vm.smtp_auth_user, "me")
89+
self.assertEqual(vm.smtp_auth_password, "hunter2")
90+
self.assertEqual(vm.smtp_use_starttls, True)
91+
92+
vm._send(mail_from="[email protected]", mail_to="[email protected]", mail_subject="s", msg=msg)
93+
94+
mock_smtplib.SMTP.assert_called_with(host="test.machine.com", port=124)
95+
mock_ssl.create_default_context.assert_called()

0 commit comments

Comments
 (0)