Skip to content

Commit 790d2e8

Browse files
authored
[change] Allow counters to return multiple replies #634
Closes #634
1 parent 340a64e commit 790d2e8

File tree

8 files changed

+137
-51
lines changed

8 files changed

+137
-51
lines changed

openwisp_radius/api/serializers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from .. import settings as app_settings
3535
from ..base.forms import PasswordResetForm
36-
from ..counters.exceptions import MaxQuotaReached, SkipCheck
36+
from ..counters.exceptions import SkipCheck
3737
from ..registration import REGISTRATION_METHOD_CHOICES
3838
from ..utils import (
3939
get_group_checks,
@@ -304,12 +304,11 @@ def get_result(self, obj):
304304
group=self.context["group"],
305305
group_check=obj,
306306
)
307-
# Python can handle 64 bit numbers and
308-
# hence we don't need to display Gigawords
309-
remaining = counter.check(gigawords=False)
310-
return int(obj.value) - remaining
311-
except MaxQuotaReached:
312-
return int(obj.value)
307+
consumed = counter.consumed()
308+
value = int(obj.value)
309+
if consumed > value:
310+
consumed = value
311+
return consumed
313312
except (SkipCheck, ValueError, KeyError):
314313
return None
315314

openwisp_radius/counters/base.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,18 @@ def check_name(self): # pragma: no cover
2121
pass
2222

2323
@property
24-
@abstractmethod
25-
def reply_name(self): # pragma: no cover
26-
pass
24+
def reply_names(self):
25+
# BACKWARD COMPATIBILITY: In previous versions of openwisp-radius,
26+
# the Counter.reply_name was a string instead of a tuple. Thus,
27+
# we need to convert it to a tuple if it's a string.
28+
reply_name = getattr(self, "reply_name", None)
29+
if not reply_name:
30+
raise NotImplementedError(
31+
"Counter classes must define 'reply_names' property."
32+
)
33+
if isinstance(reply_name, str):
34+
return (reply_name,)
35+
return reply_name
2736

2837
@property
2938
@abstractmethod
@@ -43,7 +52,6 @@ def get_sql_params(self, start_time, end_time): # pragma: no cover
4352
# sqlcounter module, now we can translate it with gettext
4453
# or customize it (in new counter classes) if needed
4554
reply_message = _("Your maximum daily usage time has been reached")
46-
gigawords = False
4755

4856
def __init__(self, user, group, group_check):
4957
self.user = user
@@ -72,7 +80,7 @@ def get_attribute_type(self):
7280

7381
def get_reset_timestamps(self):
7482
try:
75-
return resets[self.reset](self.user)
83+
return resets[self.reset](self.user, counter=self)
7684
except KeyError:
7785
raise SkipCheck(
7886
message=f'Reset time with key "{self.reset}" not available.',
@@ -93,7 +101,7 @@ def get_counter(self):
93101
# or if nothing is returned (no sessions present), return zero
94102
return row[0] or 0
95103

96-
def check(self, gigawords=gigawords):
104+
def check(self):
97105
if not self.group_check:
98106
raise SkipCheck(
99107
message=(
@@ -134,12 +142,15 @@ def check(self, gigawords=gigawords):
134142
reply_message=self.reply_message,
135143
)
136144

137-
return int(remaining)
145+
return (int(remaining),)
146+
147+
def consumed(self):
148+
return int(self.get_counter())
138149

139150

140151
class BaseDailyCounter(BaseCounter):
141152
check_name = "Max-Daily-Session"
142-
reply_name = "Session-Timeout"
153+
reply_names = ("Session-Timeout",)
143154
reset = "daily"
144155

145156
def get_sql_params(self, start_time, end_time):
@@ -152,7 +163,7 @@ def get_sql_params(self, start_time, end_time):
152163

153164

154165
class BaseTrafficCounter(BaseCounter):
155-
reply_name = app_settings.TRAFFIC_COUNTER_REPLY_NAME
166+
reply_names = (app_settings.TRAFFIC_COUNTER_REPLY_NAME,)
156167

157168
def get_sql_params(self, start_time, end_time):
158169
return [

openwisp_radius/counters/resets.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,29 @@ def _timestamp(start, end):
1111
return int(start.timestamp()), int(end.timestamp())
1212

1313

14-
def _daily(user=None):
14+
def _daily(user=None, **kwargs):
1515
dt = _today()
1616
start = datetime(dt.year, dt.month, dt.day)
1717
end = datetime(dt.year, dt.month, dt.day) + timedelta(days=1)
1818
return _timestamp(start, end)
1919

2020

21-
def _weekly(user=None):
21+
def _weekly(user=None, **kwargs):
2222
dt = _today()
2323
start = dt - timedelta(days=dt.weekday())
2424
start = datetime(start.year, start.month, start.day)
2525
end = start + timedelta(days=7)
2626
return _timestamp(start, end)
2727

2828

29-
def _monthly(user=None):
29+
def _monthly(user=None, **kwargs):
3030
dt = _today()
3131
start = datetime(dt.year, dt.month, 1)
3232
end = datetime(dt.year, dt.month, 1) + relativedelta(months=1)
3333
return _timestamp(start, end)
3434

3535

36-
def _monthly_subscription(user):
36+
def _monthly_subscription(user, **kwargs):
3737
dt = _today()
3838
day_joined = user.date_joined.day
3939
# subscription cycle starts on the day of month the user joined
@@ -45,7 +45,7 @@ def _monthly_subscription(user):
4545
return _timestamp(start, end)
4646

4747

48-
def _never(user=None):
48+
def _never(user=None, **kwargs):
4949
return 0, None
5050

5151

openwisp_radius/tests/test_api/test_api.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,44 @@ def test_user_radius_usage_view(self):
11701170
},
11711171
)
11721172

1173+
data3 = self.acct_post_data
1174+
data3.update(
1175+
dict(
1176+
session_id="40111117",
1177+
unique_id="12234f70",
1178+
input_octets=1000000000,
1179+
output_octets=1000000000,
1180+
username="tester",
1181+
)
1182+
)
1183+
self._create_radius_accounting(**data3)
1184+
1185+
with self.subTest("User consumed more than allowed limit"):
1186+
response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization)
1187+
self.assertEqual(response.status_code, 200)
1188+
self.assertIn("checks", response.data)
1189+
checks = response.data["checks"]
1190+
self.assertDictEqual(
1191+
dict(checks[0]),
1192+
{
1193+
"attribute": "Max-Daily-Session",
1194+
"op": ":=",
1195+
"value": "10800",
1196+
"result": 783,
1197+
"type": "seconds",
1198+
},
1199+
)
1200+
self.assertDictEqual(
1201+
dict(checks[1]),
1202+
{
1203+
"attribute": "Max-Daily-Session-Traffic",
1204+
"op": ":=",
1205+
"value": "3000000000",
1206+
"result": 3000000000,
1207+
"type": "bytes",
1208+
},
1209+
)
1210+
11731211
with self.subTest("Test user does not have RadiusUserGroup"):
11741212
RadiusUserGroup.objects.all().delete()
11751213
response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization)

openwisp_radius/tests/test_counters/test_base_counter.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,37 @@ def test_abstract_instantiation(self):
4242
BaseCounter(**opts)
4343
self.assertIn("abstract class BaseCounter", str(ctx.exception))
4444

45+
def test_reply_name_backward_compatibility(self):
46+
options = self._get_kwargs("Session-Timeout")
47+
48+
class BackwardCompatibleCounter(BaseCounter):
49+
check_name = "Max-Daily-Session"
50+
counter_name = "BackwardCompatibleCounter"
51+
reset = "daily"
52+
sql = "SELECT 1"
53+
54+
def get_sql_params(self, start_time, end_time):
55+
return []
56+
57+
with self.subTest("Counter does not implement reply_names or reply_name"):
58+
counter = BackwardCompatibleCounter(**options)
59+
with self.assertRaises(NotImplementedError) as ctx:
60+
counter.reply_names
61+
self.assertIn(
62+
"Counter classes must define 'reply_names' property.",
63+
str(ctx.exception),
64+
)
65+
66+
BackwardCompatibleCounter.reply_name = "Session-Timeout"
67+
with self.subTest("Counter does not implement reply_names, uses reply_name"):
68+
counter = BackwardCompatibleCounter(**options)
69+
self.assertEqual(counter.reply_names, ("Session-Timeout",))
70+
71+
BackwardCompatibleCounter.reply_name = ("Session-Timeout",)
72+
with self.subTest("Counter implements reply_names as tuple"):
73+
counter = BackwardCompatibleCounter(**options)
74+
self.assertEqual(counter.reply_names, ("Session-Timeout",))
75+
4576
@freeze_time("2021-11-03T08:21:44-04:00")
4677
def test_resets(self):
4778
with self.subTest("daily"):

openwisp_radius/tests/test_counters/test_sqlite_counters.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,22 @@ def test_time_counter_repr(self):
2828
def test_time_counter_no_sessions(self):
2929
opts = self._get_kwargs("Max-Daily-Session")
3030
counter = DailyCounter(**opts)
31-
self.assertEqual(counter.check(), int(opts["group_check"].value))
31+
self.assertEqual(counter.check(), (int(opts["group_check"].value),))
3232

3333
def test_time_counter_with_sessions(self):
3434
opts = self._get_kwargs("Max-Daily-Session")
3535
counter = DailyCounter(**opts)
3636
self._create_radius_accounting(**_acct_data)
3737
expected = int(opts["group_check"].value) - int(_acct_data["session_time"])
38-
self.assertEqual(counter.check(), expected)
38+
self.assertEqual(counter.check(), (expected,))
3939
_acct_data2 = _acct_data.copy()
4040
_acct_data2.update({"session_id": "2", "unique_id": "2", "session_time": "500"})
4141
self._create_radius_accounting(**_acct_data2)
4242
session_time = int(_acct_data["session_time"]) + int(
4343
_acct_data2["session_time"]
4444
)
4545
expected = int(opts["group_check"].value) - session_time
46-
self.assertEqual(counter.check(), expected)
46+
self.assertEqual(counter.check(), (expected,))
4747

4848
@capture_any_output()
4949
def test_counter_skip_exceptions(self):
@@ -88,7 +88,7 @@ def test_counter_skip_exceptions(self):
8888
def test_traffic_counter_no_sessions(self):
8989
opts = self._get_kwargs("Max-Daily-Session-Traffic")
9090
counter = DailyTrafficCounter(**opts)
91-
self.assertEqual(counter.check(), int(opts["group_check"].value))
91+
self.assertEqual(counter.check(), (int(opts["group_check"].value),))
9292

9393
def test_traffic_counter_with_sessions(self):
9494
opts = self._get_kwargs("Max-Daily-Session-Traffic")
@@ -98,13 +98,13 @@ def test_traffic_counter_with_sessions(self):
9898
self._create_radius_accounting(**acct)
9999
traffic = int(acct["input_octets"]) + int(acct["output_octets"])
100100
expected = int(opts["group_check"].value) - traffic
101-
self.assertEqual(counter.check(), expected)
101+
self.assertEqual(counter.check(), (expected,))
102102

103103
def test_traffic_counter_reply_and_check_name(self):
104104
opts = self._get_kwargs("Max-Daily-Session-Traffic")
105105
counter = DailyTrafficCounter(**opts)
106106
self.assertEqual(counter.check_name, "Max-Daily-Session-Traffic")
107-
self.assertEqual(counter.reply_name, "CoovaChilli-Max-Total-Octets")
107+
self.assertEqual(counter.reply_names[0], "CoovaChilli-Max-Total-Octets")
108108

109109
def test_monthly_traffic_counter_with_sessions(self):
110110
rg = RadiusGroup.objects.filter(name="test-org-users").first()
@@ -121,7 +121,7 @@ def test_monthly_traffic_counter_with_sessions(self):
121121
self._create_radius_accounting(**acct)
122122
traffic = int(acct["input_octets"]) + int(acct["output_octets"])
123123
expected = int(opts["group_check"].value) - traffic
124-
self.assertEqual(counter.check(), expected)
124+
self.assertEqual(counter.check(), (expected,))
125125

126126

127127
del BaseTransactionTestCase

openwisp_radius/tests/test_selenium.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.contrib.auth import get_user_model
22
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
3+
from django.test import tag
34
from django.urls import reverse
45
from selenium.webdriver.common.by import By
56
from selenium.webdriver.support.ui import Select
@@ -15,6 +16,7 @@
1516
OrganizationRadiusSettings = load_model("OrganizationRadiusSettings")
1617

1718

19+
@tag("selenium_tests")
1820
class BasicTest(
1921
SeleniumTestMixin, FileMixin, StaticLiveServerTestCase, TestOrganizationMixin
2022
):
@@ -44,8 +46,7 @@ def test_batch_user_creation(self):
4446
# Select the previously created organization
4547
option = self.find_element(
4648
By.XPATH,
47-
"//li[contains(@class, 'select2-results__option') and "
48-
"text()='test org']",
49+
"//li[contains(@class, 'select2-results__option') and text()='test org']",
4950
10,
5051
)
5152
option.click()
@@ -87,8 +88,7 @@ def test_standard_csv_import(self):
8788
organization.click()
8889
option = self.find_element(
8990
By.XPATH,
90-
"//li[contains(@class, 'select2-results__option') and "
91-
"text()='test org']",
91+
"//li[contains(@class, 'select2-results__option') and text()='test org']",
9292
10,
9393
)
9494
option.click()
@@ -135,8 +135,7 @@ def test_import_with_hashed_passwords(self):
135135
organization.click()
136136
option = self.find_element(
137137
By.XPATH,
138-
"//li[contains(@class, 'select2-results__option') and "
139-
"text()='test org']",
138+
"//li[contains(@class, 'select2-results__option') and text()='test org']",
140139
10,
141140
)
142141
option.click()
@@ -179,8 +178,7 @@ def test_csv_user_generation(self):
179178
organization.click()
180179
option = self.find_element(
181180
By.XPATH,
182-
"//li[contains(@class, 'select2-results__option') and "
183-
"text()='test org']",
181+
"//li[contains(@class, 'select2-results__option') and text()='test org']",
184182
10,
185183
)
186184
option.click()

0 commit comments

Comments
 (0)