Skip to content

Commit 4454fcd

Browse files
authored
Merge pull request #7062 from aldbr/rel-v8r0_FIX_TokenPilotSubmission
[8.0] fix: interacting with CEs using tokens
2 parents 2113fec + 6c2b4f8 commit 4454fcd

File tree

16 files changed

+355
-87
lines changed

16 files changed

+355
-87
lines changed

src/DIRAC/Core/Utilities/Grid.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from DIRAC.Core.Utilities.Subprocess import systemCall, shellCall
1212

1313

14-
def executeGridCommand(proxy, cmd, gridEnvScript=None, gridEnvDict=None):
14+
def executeGridCommand(cmd, gridEnvScript=None, gridEnvDict=None):
1515
"""
1616
Execute cmd tuple after sourcing GridEnv
1717
"""
@@ -37,22 +37,6 @@ def executeGridCommand(proxy, cmd, gridEnvScript=None, gridEnvDict=None):
3737
else:
3838
gridEnv = currentEnv
3939

40-
if not proxy:
41-
res = getProxyInfo()
42-
if not res["OK"]:
43-
return res
44-
gridEnv["X509_USER_PROXY"] = res["Value"]["path"]
45-
elif isinstance(proxy, str):
46-
if os.path.exists(proxy):
47-
gridEnv["X509_USER_PROXY"] = proxy
48-
else:
49-
return S_ERROR("Can not treat proxy passed as a string")
50-
else:
51-
ret = gProxyManager.dumpProxyToFile(proxy)
52-
if not ret["OK"]:
53-
return ret
54-
gridEnv["X509_USER_PROXY"] = ret["Value"]
55-
5640
if gridEnvDict:
5741
gridEnv.update(gridEnvDict)
5842

src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def getToken(
3434
self,
3535
username: str = None,
3636
userGroup: str = None,
37-
scope: str = None,
37+
scope: list[str] = None,
3838
audience: str = None,
3939
identityProvider: str = None,
4040
requiredTimeLeft: int = 0,

src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,20 @@ def initializeHandler(cls, *args):
6161
6262
:return: S_OK()/S_ERROR()
6363
"""
64+
# Cache containing tokens from scope requested by the client
65+
cls.__tokensCache = DictCache()
66+
67+
# The service plays an important OAuth 2.0 role, namely it is an Identity Provider client.
68+
# This allows you to manage tokens without the involvement of their owners.
69+
cls.idps = IdProviderFactory()
70+
6471
# Let's try to connect to the database
6572
try:
6673
cls.__tokenDB = TokenDB(parentLogger=cls.log)
6774
except Exception as e:
6875
cls.log.exception(e)
6976
return S_ERROR(f"Could not connect to the database {repr(e)}")
7077

71-
# Cache containing tokens from scope requested by the client
72-
cls.__tokensCache = DictCache()
73-
74-
# The service plays an important OAuth 2.0 role, namely it is an Identity Provider client.
75-
# This allows you to manage tokens without the involvement of their owners.
76-
cls.idps = IdProviderFactory()
7778
return S_OK()
7879

7980
def export_getUserTokensInfo(self):
@@ -185,7 +186,7 @@ def export_getToken(
185186
self,
186187
username: str = None,
187188
userGroup: str = None,
188-
scope: str = None,
189+
scope: list[str] = None,
189190
audience: str = None,
190191
identityProvider: str = None,
191192
requiredTimeLeft: int = 0,

src/DIRAC/FrameworkSystem/Utilities/TokenManagementUtilities.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def getCachedKey(
3030
idProviderClient,
3131
username: str = None,
3232
userGroup: str = None,
33-
scope: str = None,
33+
scope: list[str] = None,
3434
audience: str = None,
3535
):
3636
"""Build the key to potentially retrieve a cached token given the provided parameters.
@@ -53,7 +53,9 @@ def getCachedKey(
5353
if userGroup and (result := idProviderClient.getGroupScopes(userGroup)):
5454
# What scope correspond to the requested group?
5555
scope = list(set((scope or []) + result))
56-
scope = " ".join(scope)
56+
57+
if scope:
58+
scope = " ".join(sorted(scope))
5759

5860
return (subject, scope, audience, idProviderClient.name, idProviderClient.issuer)
5961

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
""" Test IdProvider Factory"""
2+
import pytest
3+
import time
4+
5+
from DIRAC import S_ERROR, S_OK
6+
from DIRAC.Core.Utilities.DictCache import DictCache
7+
from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token
8+
from DIRAC.FrameworkSystem.Utilities.TokenManagementUtilities import getCachedKey, getCachedToken
9+
from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider
10+
11+
12+
@pytest.mark.parametrize(
13+
"idProviderType, idProviderName, issuer, username, group, scope, audience, expectedValue",
14+
[
15+
# Only a client name: this is mandatory
16+
(OAuth2IdProvider, "IdPTest", "Issuer1", None, None, None, None, ("IdPTest", None, None, "IdPTest", "Issuer1")),
17+
(
18+
OAuth2IdProvider,
19+
"IdPTest2",
20+
"Issuer1",
21+
None,
22+
None,
23+
None,
24+
None,
25+
("IdPTest2", None, None, "IdPTest2", "Issuer1"),
26+
),
27+
(
28+
OAuth2IdProvider,
29+
"IdPTest2",
30+
"Issuer2",
31+
None,
32+
None,
33+
None,
34+
None,
35+
("IdPTest2", None, None, "IdPTest2", "Issuer2"),
36+
),
37+
# Client name and username
38+
(OAuth2IdProvider, "IdPTest", "Issuer1", "user", None, None, None, ("user", None, None, "IdPTest", "Issuer1")),
39+
# Client name and group (should not add any permission in scope)
40+
(
41+
OAuth2IdProvider,
42+
"IdPTest",
43+
"Issuer1",
44+
None,
45+
"group",
46+
None,
47+
None,
48+
("IdPTest", None, None, "IdPTest", "Issuer1"),
49+
),
50+
# Client name and scope
51+
(
52+
OAuth2IdProvider,
53+
"IdPTest",
54+
"Issuer1",
55+
None,
56+
None,
57+
["permission:1", "permission:2"],
58+
None,
59+
("IdPTest", "permission:1 permission:2", None, "IdPTest", "Issuer1"),
60+
),
61+
(
62+
OAuth2IdProvider,
63+
"IdPTest",
64+
"Issuer1",
65+
None,
66+
None,
67+
["permission:2", "permission:1"],
68+
None,
69+
("IdPTest", "permission:1 permission:2", None, "IdPTest", "Issuer1"),
70+
),
71+
# Client name and audience
72+
(
73+
OAuth2IdProvider,
74+
"IdPTest",
75+
"Issuer1",
76+
None,
77+
None,
78+
None,
79+
"CE1",
80+
("IdPTest", None, "CE1", "IdPTest", "Issuer1"),
81+
),
82+
# Client name, username, group
83+
(
84+
OAuth2IdProvider,
85+
"IdPTest",
86+
"Issuer1",
87+
"user",
88+
"group1",
89+
None,
90+
None,
91+
("user", None, None, "IdPTest", "Issuer1"),
92+
),
93+
# Client name, username, scope
94+
(
95+
OAuth2IdProvider,
96+
"IdPTest",
97+
"Issuer1",
98+
"user",
99+
None,
100+
["permission:1", "permission:2"],
101+
None,
102+
("user", "permission:1 permission:2", None, "IdPTest", "Issuer1"),
103+
),
104+
# Client name, username, audience
105+
(
106+
OAuth2IdProvider,
107+
"IdPTest",
108+
"Issuer1",
109+
"user",
110+
None,
111+
None,
112+
"CE1",
113+
("user", None, "CE1", "IdPTest", "Issuer1"),
114+
),
115+
# Client name, username, group, scope
116+
(
117+
OAuth2IdProvider,
118+
"IdPTest",
119+
"Issuer1",
120+
"user",
121+
"group1",
122+
["permission:1", "permission:2"],
123+
None,
124+
("user", "permission:1 permission:2", None, "IdPTest", "Issuer1"),
125+
),
126+
# Client name, username, group, audience
127+
(
128+
OAuth2IdProvider,
129+
"IdPTest",
130+
"Issuer1",
131+
"user",
132+
"group1",
133+
None,
134+
"CE1",
135+
("user", None, "CE1", "IdPTest", "Issuer1"),
136+
),
137+
# Client name, usergroup, scope, audience
138+
(
139+
OAuth2IdProvider,
140+
"IdPTest",
141+
"Issuer1",
142+
"user",
143+
"group1",
144+
["permission:1", "permission:2"],
145+
"CE1",
146+
("user", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"),
147+
),
148+
],
149+
)
150+
def test_getCachedKey(idProviderType, idProviderName, issuer, username, group, scope, audience, expectedValue):
151+
"""Test getCachedKey"""
152+
# Prepare IdP
153+
idProviderClient = idProviderType()
154+
idProviderClient.name = idProviderName
155+
idProviderClient.issuer = issuer
156+
157+
result = getCachedKey(idProviderClient, username, group, scope, audience)
158+
assert result == expectedValue
159+
160+
161+
@pytest.mark.parametrize(
162+
"cachedKey, requiredTimeLeft, expectedValue",
163+
[
164+
# Normal case
165+
(("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"), 0, S_OK()),
166+
# Empty cachedKey
167+
((), 0, S_ERROR("The key does not exist")),
168+
# Wrong cachedKey
169+
(("IdPTest", "permission:1", "CE1", "IdPTest", "Issuer1"), 0, S_ERROR("The key does not exist")),
170+
# Expired token (650 > 150)
171+
(
172+
("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"),
173+
650,
174+
S_ERROR("Token found but expired"),
175+
),
176+
# Expired cachedKey (1500 > 1200)
177+
(
178+
("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"),
179+
1500,
180+
S_ERROR("The key does not exist"),
181+
),
182+
],
183+
)
184+
def test_getCachedToken(cachedKey, requiredTimeLeft, expectedValue):
185+
"""Test getCachedToken"""
186+
# Prepare cachedToken dictionary
187+
cachedTokens = DictCache()
188+
currentTime = time.time()
189+
token = {
190+
"sub": "0001234",
191+
"aud": "CE1",
192+
"nbf": currentTime - 150,
193+
"scope": "permission:1 permission:2",
194+
"iss": "Issuer1",
195+
"exp": currentTime + 150,
196+
"iat": currentTime - 150,
197+
"jti": "000001234",
198+
"client_id": "0001234",
199+
}
200+
tokenKey = ("IdPTest", "permission:1 permission:2", token["aud"], "IdPTest", token["iss"])
201+
cachedTokens.add(tokenKey, 1200, OAuth2Token(token))
202+
203+
# Try to get the token from the cache
204+
result = getCachedToken(cachedTokens, cachedKey, requiredTimeLeft)
205+
assert result["OK"] == expectedValue["OK"]
206+
if result["OK"]:
207+
resultToken = result["Value"]
208+
assert resultToken["sub"] == token["sub"]
209+
assert resultToken["scope"] == token["scope"]
210+
else:
211+
assert result["Message"] == expectedValue["Message"]

src/DIRAC/Resources/Computing/AREXComputingElement.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def _reset(self):
6565

6666
# Get options from the ceParameters dictionary
6767
self.port = self.ceParameters.get("Port", self.port)
68+
self.audienceName = f"https://{self.ceName}:{self.port}"
69+
6870
self.restVersion = self.ceParameters.get("RESTVersion", self.restVersion)
6971

7072
self.proxyTimeLeftBeforeRenewal = self.ceParameters.get(

src/DIRAC/Resources/Computing/ComputingElement.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,21 @@ def __init__(self, ceName):
7171
self.log = gLogger.getSubLogger(ceName)
7272
self.ceName = ceName
7373
self.ceParameters = {}
74-
self.proxy = ""
75-
self.token = None
76-
self.valid = None
7774
self.mandatoryParameters = []
78-
self.batchSystem = None
79-
self.taskResults = {}
75+
76+
# Token audience
77+
# None by default, it needs to be redefined in subclasses
78+
self.audienceName = None
79+
self.token = None
80+
81+
self.proxy = ""
8082
self.minProxyTime = gConfig.getValue("/Registry/MinProxyLifeTime", 10800) # secs
8183
self.defaultProxyTime = gConfig.getValue("/Registry/DefaultProxyLifeTime", 43200) # secs
8284
self.proxyCheckPeriod = gConfig.getValue("/Registry/ProxyCheckingPeriod", 3600) # secs
85+
self.valid = None
86+
87+
self.batchSystem = None
88+
self.taskResults = {}
8389

8490
clsName = self.__class__.__name__
8591
if clsName.endswith("ComputingElement"):

0 commit comments

Comments
 (0)