Skip to content

Commit db96f39

Browse files
committed
reworked test_clients
1 parent c0b5a02 commit db96f39

File tree

1 file changed

+16
-44
lines changed

1 file changed

+16
-44
lines changed
Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import importlib
2-
import logging
31
import unittest
42
from unittest.mock import MagicMock, patch
53

6-
import common.clients as clients
4+
from moto import mock_aws
75

6+
import common.clients
87

8+
9+
@mock_aws
910
class TestClients(unittest.TestCase):
1011
BUCKET_NAME = "default-bucket"
1112
AWS_REGION = "eu-west-2"
@@ -14,17 +15,14 @@ def setUp(self):
1415
# Patch boto3.client
1516
self.boto3_client_patch = patch("boto3.client", autospec=True)
1617
self.mock_boto3_client = self.boto3_client_patch.start()
17-
self.addCleanup(self.boto3_client_patch.stop)
1818

1919
# Patch logging.getLogger
2020
self.logging_patch = patch("logging.getLogger", autospec=True)
2121
self.mock_getLogger = self.logging_patch.start()
22-
self.addCleanup(self.logging_patch.stop)
2322

2423
# Patch os.getenv
2524
self.getenv_patch = patch("os.getenv", autospec=True)
2625
self.mock_getenv = self.getenv_patch.start()
27-
self.addCleanup(self.getenv_patch.stop)
2826

2927
# Set environment variable mock return values
3028
self.mock_getenv.side_effect = lambda key, default=None: {
@@ -36,57 +34,31 @@ def setUp(self):
3634
self.mock_logger_instance = MagicMock()
3735
self.mock_getLogger.return_value = self.mock_logger_instance
3836

39-
# Reload the module under test to apply patches
40-
importlib.reload(clients)
41-
42-
def test_env_variables_loaded(self):
43-
"""Test that environment variables are loaded correctly"""
44-
self.assertEqual(clients.CONFIG_BUCKET_NAME, self.BUCKET_NAME)
45-
self.assertEqual(clients.REGION_NAME, self.AWS_REGION)
46-
47-
def test_boto3_client_created_for_s3(self):
48-
"""Test that S3 boto3 client is created with correct region"""
49-
importlib.reload(clients)
50-
clients.get_s3_client()
51-
self.mock_boto3_client.assert_any_call("s3", region_name=self.AWS_REGION)
52-
53-
def test_boto3_client_created_for_firehose(self):
54-
"""Test that Firehose boto3 client is created with correct region"""
55-
self.mock_boto3_client.assert_any_call("firehose", region_name=self.AWS_REGION)
56-
57-
def test_logger_is_initialized(self):
58-
"""Test that a logger instance is initialized"""
59-
self.mock_getLogger.assert_called_once_with()
60-
self.assertTrue(hasattr(clients, "logger"))
61-
62-
def test_logger_set_level(self):
63-
"""Test that logger level is set to INFO"""
64-
self.mock_logger_instance.setLevel.assert_called_once_with(logging.INFO)
37+
def tearDown(self):
38+
self.getenv_patch.stop()
39+
self.logging_patch.stop()
40+
self.boto3_client_patch.stop()
6541

6642
def test_global_s3_client(self):
6743
"""Test global_s3_client is not initialized on import"""
68-
importlib.reload(clients)
69-
self.assertEqual(clients.global_s3_client, None)
44+
self.assertEqual(common.clients.global_s3_client, None)
7045

7146
def test_global_s3_client_initialization(self):
7247
"""Test global_s3_client is initialized exactly once even with multiple invocations"""
73-
importlib.reload(clients)
74-
clients.get_s3_client()
75-
self.assertNotEqual(clients.global_s3_client, None)
48+
common.clients.get_s3_client()
49+
self.assertNotEqual(common.clients.global_s3_client, None)
7650
call_count = self.mock_boto3_client.call_count
77-
clients.get_s3_client()
51+
common.clients.get_s3_client()
7852
self.assertEqual(self.mock_boto3_client.call_count, call_count)
7953

8054
def test_global_sqs_client(self):
8155
"""Test global_sqs_client is not initialized on import"""
82-
importlib.reload(clients)
83-
self.assertEqual(clients.global_sqs_client, None)
56+
self.assertEqual(common.clients.global_sqs_client, None)
8457

8558
def test_global_sqs_client_initialization(self):
8659
"""Test global_sqs_client is initialized exactly once even with multiple invocations"""
87-
importlib.reload(clients)
88-
clients.get_sqs_client()
89-
self.assertNotEqual(clients.global_sqs_client, None)
60+
common.clients.get_sqs_client()
61+
self.assertNotEqual(common.clients.global_sqs_client, None)
9062
call_count = self.mock_boto3_client.call_count
91-
clients.get_sqs_client()
63+
common.clients.get_sqs_client()
9264
self.assertEqual(self.mock_boto3_client.call_count, call_count)

0 commit comments

Comments
 (0)