Skip to content

Commit a73171f

Browse files
committed
user type in rate limits config
1 parent 31771b0 commit a73171f

File tree

4 files changed

+82
-49
lines changed

4 files changed

+82
-49
lines changed

cads_processing_api_service/config.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class RateLimitsRouteParamConfig(pydantic.BaseModel):
116116
model_config = pydantic.ConfigDict(extra="allow")
117117

118118

119-
class RateLimitsConfig(pydantic.BaseModel):
119+
class RateLimitsUserConfig(pydantic.BaseModel):
120120
default: RateLimitsRouteConfig = pydantic.Field(
121121
default=RateLimitsRouteConfig(), validate_default=True
122122
)
@@ -151,6 +151,18 @@ class RateLimitsConfig(pydantic.BaseModel):
151151
)
152152

153153

154+
class RateLimitsConfig(pydantic.BaseModel):
155+
"""Rate limits configuration."""
156+
157+
auth: RateLimitsUserConfig = pydantic.Field(
158+
default=RateLimitsUserConfig(),
159+
description="Rate limits for authenticated users.",
160+
)
161+
anon: RateLimitsUserConfig = pydantic.Field(
162+
default=RateLimitsUserConfig(), description="Rate limits for anonymous users."
163+
)
164+
165+
154166
def load_rate_limits(rate_limits_file: str | None) -> RateLimitsConfig:
155167
rate_limits = RateLimitsConfig()
156168
if rate_limits_file is not None:

cads_processing_api_service/limits.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,16 @@
3131

3232
def get_rate_limits(
3333
rate_limits_config: config.RateLimitsConfig,
34+
user_type: str,
3435
route: str,
3536
method: str,
3637
request_origin: str,
3738
route_param: str | None = None,
3839
) -> list[str]:
3940
"""Get the rate limits for a specific route and method."""
4041
rate_limits = rate_limits_config.model_dump()
41-
route_rate_limits: dict[str, Any] = rate_limits.get(route, {})
42+
user_type_rate_limits: dict[str, Any] = rate_limits.get(user_type, {})
43+
route_rate_limits: dict[str, Any] = user_type_rate_limits.get(route, {})
4244
if route_param is not None:
4345
route_param_rate_limits: dict[str, Any] = route_rate_limits.get(route_param, {})
4446
else:
@@ -50,22 +52,23 @@ def get_rate_limits(
5052

5153
def get_rate_limits_defaulted(
5254
rate_limits_config: config.RateLimitsConfig,
55+
user_type: str,
5356
route: str,
5457
method: str,
5558
request_origin: str,
5659
route_param: str | None = None,
5760
) -> list[str]:
5861
"""Get the rate limits for a specific route and method, with defaults."""
5962
rate_limits = get_rate_limits(
60-
rate_limits_config, route, method, request_origin, route_param
63+
rate_limits_config, user_type, route, method, request_origin, route_param
6164
)
6265
if not rate_limits:
6366
rate_limits = get_rate_limits(
64-
rate_limits_config, route, method, request_origin, "default"
67+
rate_limits_config, user_type, route, method, request_origin, "default"
6568
)
6669
if not rate_limits:
6770
rate_limits = get_rate_limits(
68-
rate_limits_config, "default", method, request_origin
71+
rate_limits_config, user_type, "default", method, request_origin
6972
)
7073
return rate_limits
7174

@@ -104,8 +107,9 @@ def check_rate_limits(
104107
"""Check if the rate limits are exceeded."""
105108
request_origin = auth_info.request_origin
106109
user_uid = auth_info.user_uid
110+
user_type = "anon" if auth_info.user_uid == "unauthenticated" else "auth"
107111
rate_limits = get_rate_limits_defaulted(
108-
rate_limits_config, route, method, request_origin, route_param
112+
rate_limits_config, user_type, route, method, request_origin, route_param
109113
)
110114
rate_limits_parsed = [limits.parse(rate_limit) for rate_limit in rate_limits]
111115
check_rate_limits_for_user(user_uid, rate_limits_parsed)

tests/test_10_config.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,13 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
7373

7474
rate_limits_file = str(tmp_path / "rate-limits.yaml")
7575
rate_limits = {
76-
"/jobs/{job_id}": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
77-
"/processes/{process_id}/constraints": {
78-
"default": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
79-
"process-id": {"post": {"api": ["1/second"], "ui": ["2/second"]}},
80-
},
76+
"auth": {
77+
"/jobs/{job_id}": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
78+
"/processes/{process_id}/constraints": {
79+
"default": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
80+
"process-id": {"post": {"api": ["1/second"], "ui": ["2/second"]}},
81+
},
82+
}
8183
}
8284
with open(rate_limits_file, "w") as file:
8385
yaml.dump(rate_limits, file)
@@ -87,7 +89,8 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
8789
"post": {"api": [], "ui": []},
8890
"delete": {"api": [], "ui": []},
8991
}
90-
assert loaded_rate_limits["jobs_jobsid"] == expected_jobs_limits
92+
assert "auth" in loaded_rate_limits
93+
assert loaded_rate_limits["auth"]["jobs_jobsid"] == expected_jobs_limits
9194
expected_process_constraints_limits = {
9295
"default": {
9396
"get": {"api": ["1/second"], "ui": ["2/second"]},
@@ -101,13 +104,15 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
101104
},
102105
}
103106
assert (
104-
loaded_rate_limits["processes_processid_constraints"]
107+
loaded_rate_limits["auth"]["processes_processid_constraints"]
105108
== expected_process_constraints_limits
106109
)
107110

108111
rate_limits_file = str(tmp_path / "invalid-rate-limits.yaml")
109112
rate_limits = {
110-
"/jobs/{job_id}": {"get": {"api": ["invalid_limit"]}},
113+
"auth": {
114+
"/jobs/{job_id}": {"get": {"api": ["invalid_limit"]}},
115+
}
111116
}
112117
with open(rate_limits_file, "w") as file:
113118
yaml.dump(rate_limits, file)

tests/test_30_limits.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,68 +22,76 @@
2222

2323

2424
def test_get_rate_limits() -> None:
25-
rate_limits = {"/jobs/{job_id}": {"get": {"api": ["2/second"]}}}
25+
rate_limits = {"auth": {"/jobs/{job_id}": {"get": {"api": ["2/second"]}}}}
2626
rate_limits_config = config.RateLimitsConfig(**rate_limits)
2727

28+
user_type = "auth"
2829
route = "jobs_jobsid"
2930
method = "get"
3031
request_origin = "api"
3132
rate_limits = cads_processing_api_service.limits.get_rate_limits(
32-
rate_limits_config, route, method, request_origin
33+
rate_limits_config, user_type, route, method, request_origin
3334
)
3435
exp_rate_limits = ["2/second"]
3536
assert rate_limits == exp_rate_limits
3637

3738

3839
def test_get_rate_limits_route_param() -> None:
3940
rate_limits = {
40-
"/processes/{process_id}/execution": {
41-
"process_id": {"post": {"api": ["2/second"]}}
41+
"auth": {
42+
"/processes/{process_id}/execution": {
43+
"process_id": {"post": {"api": ["2/second"]}}
44+
}
4245
}
4346
}
4447
rate_limits_config = config.RateLimitsConfig(**rate_limits)
4548

49+
user_type = "auth"
4650
route = "processes_processid_execution"
4751
route_param = "process_id"
4852
method = "post"
4953
request_origin = "api"
5054
rate_limits = cads_processing_api_service.limits.get_rate_limits(
51-
rate_limits_config, route, method, request_origin, route_param
55+
rate_limits_config, user_type, route, method, request_origin, route_param
5256
)
5357
exp_rate_limits = ["2/second"]
5458
assert rate_limits == exp_rate_limits
5559

5660

5761
def test_get_rate_limits_defaulted_actual_value() -> None:
5862
rate_limits = {
59-
"/jobs/{job_id}": {"get": {"api": ["2/second"]}},
60-
"default": {"get": {"api": ["1/second"]}},
63+
"auth": {
64+
"/jobs/{job_id}": {"get": {"api": ["2/second"]}},
65+
"default": {"get": {"api": ["1/second"]}},
66+
}
6167
}
6268
rate_limits_config = config.RateLimitsConfig(**rate_limits)
63-
69+
user_type = "auth"
6470
route = "jobs_jobsid"
6571
method = "get"
6672
request_origin = "api"
6773
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
68-
rate_limits_config, route, method, request_origin
74+
rate_limits_config, user_type, route, method, request_origin
6975
)
7076
exp_rate_limits = ["2/second"]
7177
assert rate_limits == exp_rate_limits
7278

7379

7480
def test_get_rate_limits_defaulted_default_value() -> None:
7581
rate_limits = {
76-
"/jobs/{job_id}": {"post": {"api": ["2/second"]}},
77-
"/jobs": {"get": {"api": ["2/second"]}},
78-
"default": {"post": {"ui": ["1/second"]}},
82+
"auth": {
83+
"/jobs/{job_id}": {"post": {"api": ["2/second"]}},
84+
"/jobs": {"get": {"api": ["2/second"]}},
85+
"default": {"post": {"ui": ["1/second"]}},
86+
}
7987
}
8088
rate_limits_config = config.RateLimitsConfig(**rate_limits)
81-
89+
user_type = "auth"
8290
route = "jobs_jobsid"
8391
method = "post"
8492
request_origin = "ui"
8593
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
86-
rate_limits_config, route, method, request_origin
94+
rate_limits_config, user_type, route, method, request_origin
8795
)
8896
exp_rate_limits = ["1/second"]
8997
assert rate_limits == exp_rate_limits
@@ -92,7 +100,7 @@ def test_get_rate_limits_defaulted_default_value() -> None:
92100
method = "post"
93101
request_origin = "ui"
94102
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
95-
rate_limits_config, route, method, request_origin
103+
rate_limits_config, user_type, route, method, request_origin
96104
)
97105
exp_rate_limits = ["1/second"]
98106
assert rate_limits == exp_rate_limits
@@ -101,48 +109,52 @@ def test_get_rate_limits_defaulted_default_value() -> None:
101109
method = "post"
102110
request_origin = "ui"
103111
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
104-
rate_limits_config, route, method, request_origin
112+
rate_limits_config, user_type, route, method, request_origin
105113
)
106114
exp_rate_limits = ["1/second"]
107115
assert rate_limits == exp_rate_limits
108116

109117

110118
def test_get_rate_limits_defaulted_route_param_actual_value() -> None:
111119
rate_limits = {
112-
"/processes/{process_id}/execution": {
113-
"test_process_id": {"post": {"api": ["2/second"]}}
114-
},
115-
"default": {"post": {"ui": ["1/second"]}},
120+
"auth": {
121+
"/processes/{process_id}/execution": {
122+
"test_process_id": {"post": {"api": ["2/second"]}}
123+
},
124+
"default": {"post": {"ui": ["1/second"]}},
125+
}
116126
}
117127
rate_limits_config = config.RateLimitsConfig(**rate_limits)
118-
128+
user_type = "auth"
119129
route = "processes_processid_execution"
120130
method = "post"
121131
request_origin = "api"
122132
route_param = "test_process_id"
123133
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
124-
rate_limits_config, route, method, request_origin, route_param
134+
rate_limits_config, user_type, route, method, request_origin, route_param
125135
)
126136
exp_rate_limits = ["2/second"]
127137
assert rate_limits == exp_rate_limits
128138

129139

130140
def test_get_rate_limits_defaulted_route_param_default_value() -> None:
131141
rate_limits = {
132-
"/processes/{process_id}/execution": {
133-
"test_process_id": {"post": {"api": ["2/second"]}},
134-
"default": {"post": {"api": ["1/second"]}},
135-
},
136-
"default": {"post": {"ui": ["1/minute"]}},
142+
"auth": {
143+
"/processes/{process_id}/execution": {
144+
"test_process_id": {"post": {"api": ["2/second"]}},
145+
"default": {"post": {"api": ["1/second"]}},
146+
},
147+
"default": {"post": {"ui": ["1/minute"]}},
148+
}
137149
}
138150
rate_limits_config = config.RateLimitsConfig(**rate_limits)
139-
151+
user_type = "auth"
140152
route = "processes_processid_execution"
141153
method = "post"
142154
request_origin = "api"
143155
route_param = "missing_test_process_id"
144156
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
145-
rate_limits_config, route, method, request_origin, route_param
157+
rate_limits_config, user_type, route, method, request_origin, route_param
146158
)
147159
exp_rate_limits = ["1/second"]
148160
assert rate_limits == exp_rate_limits
@@ -152,21 +164,21 @@ def test_get_rate_limits_defaulted_route_param_default_value() -> None:
152164
request_origin = "ui"
153165
route_param = "missing_test_process_id"
154166
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
155-
rate_limits_config, route, method, request_origin, route_param
167+
rate_limits_config, user_type, route, method, request_origin, route_param
156168
)
157169
exp_rate_limits = ["1/minute"]
158170
assert rate_limits == exp_rate_limits
159171

160172

161173
def test_get_rate_limits_undefined() -> None:
162-
rate_limits = {"/jobs": {"get": {"api": ["2/second"]}}}
174+
rate_limits = {"auth": {"/jobs": {"get": {"api": ["2/second"]}}}}
163175
rate_limits_config = config.RateLimitsConfig.model_validate(rate_limits)
164-
176+
user_type = "auth"
165177
route = "jobs"
166178
method = "get"
167179
request_origin = "ui"
168180
rate_limits = cads_processing_api_service.limits.get_rate_limits(
169-
rate_limits_config, route, method, request_origin
181+
rate_limits_config, user_type, route, method, request_origin
170182
)
171183
exp_rate_limits = []
172184
assert rate_limits == exp_rate_limits
@@ -175,7 +187,7 @@ def test_get_rate_limits_undefined() -> None:
175187
method = "post"
176188
request_origin = "ui"
177189
rate_limits = cads_processing_api_service.limits.get_rate_limits(
178-
rate_limits_config, route, method, request_origin
190+
rate_limits_config, user_type, route, method, request_origin
179191
)
180192
exp_rate_limits = []
181193
assert rate_limits == exp_rate_limits
@@ -184,7 +196,7 @@ def test_get_rate_limits_undefined() -> None:
184196
method = "get"
185197
request_origin = "ui"
186198
rate_limits = cads_processing_api_service.limits.get_rate_limits(
187-
rate_limits_config, route, method, request_origin
199+
rate_limits_config, user_type, route, method, request_origin
188200
)
189201
exp_rate_limits = []
190202
assert rate_limits == exp_rate_limits

0 commit comments

Comments
 (0)