diff --git a/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLogin.java b/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLogin.java
index 64c24b74689..781cf1d689d 100644
--- a/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLogin.java
+++ b/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLogin.java
@@ -31,15 +31,16 @@
import java.security.SecureRandom;
import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
/**
* A persistent login feature ("remember me") similar to the implementation in Spring Security,
* which is based on Improved Persistent Login Cookie
* Best Practice .
- *
+ *
* The one-time tokens generated by this class are purely random and do not contain a user name or other information. For security reasons,
* tokens and user information are not stored anywhere, so if the database is shut down, registered tokens will be gone.
- *
+ *
* The one-time token approach has the negative effect that requests need to be made in sequence, which is sometimes difficult if an app uses
* concurrent AJAX requests. Unfortunately, this is the price we have to pay for a sufficiently secure protection against
* cookie stealing attacks.
@@ -60,11 +61,11 @@ public static PersistentLogin getInstance() {
public final static int DEFAULT_TOKEN_LENGTH = 16;
- public final static int INVALIDATION_TIMEOUT = 20000;
+ public final static int INVALIDATION_TIMEOUT = 10000;
- private Map seriesMap = Collections.synchronizedMap(new HashMap<>());
+ private final Map seriesMap = new ConcurrentHashMap<>();
- private SecureRandom random;
+ private final SecureRandom random;
public PersistentLogin() {
random = new SecureRandom();
@@ -73,10 +74,10 @@ public PersistentLogin() {
/**
* Register the user and generate a first login token which will be valid for the next
* call to {@link #lookup(String)}.
- *
+ *
* The generated token will have the format base64(series-hash):base64(token-hash).
*
- * @param user the user name
+ * @param user the username
* @param password the password
* @param timeToLive timeout of the token
* @return a first login token
@@ -154,17 +155,18 @@ private String generateToken() {
public class LoginDetails {
- private String userName;
- private String password;
private String token;
- private String series;
- private long expires;
- private DurationValue timeToLive;
+
+ private final String userName;
+ private final String password;
+ private final String series;
+ private final long expires;
+ private final DurationValue timeToLive;
// disable sequential token checking by default
- private boolean seqBehavior = false;
+ private final boolean seqBehavior = false;
- private Map invalidatedTokens = new HashMap<>();
+ private final Map invalidatedTokens = new HashMap<>();
public LoginDetails(String user, String password, DurationValue timeToLive, long expires) {
this.userName = user;
@@ -176,19 +178,19 @@ public LoginDetails(String user, String password, DurationValue timeToLive, long
}
public String getToken() {
- return this.token;
+ return token;
}
public String getSeries() {
- return this.series;
+ return series;
}
public String getUser() {
- return this.userName;
+ return userName;
}
public String getPassword() {
- return this.password;
+ return password;
}
public DurationValue getTimeToLive() {
@@ -197,13 +199,15 @@ public DurationValue getTimeToLive() {
public boolean checkAndUpdateToken(String token) {
if (this.token.equals(token)) {
+ timeoutCheck();
update();
return true;
}
// check map of invalidating tokens
- Long timeout = invalidatedTokens.get(token);
- if (timeout == null)
+ final Long timeout = invalidatedTokens.get(token);
+ if (timeout == null) {
return false;
+ }
// timed out: remove
if (System.currentTimeMillis() > timeout) {
invalidatedTokens.remove(token);
@@ -213,23 +217,21 @@ public boolean checkAndUpdateToken(String token) {
return true;
}
- public String update() {
- timeoutCheck();
- // leave a small time window until previous token is deleted
+ public void update() {
+ // leave a small time-window until previous token is deleted
// to allow for concurrent requests
- invalidatedTokens.put(this.token, System.currentTimeMillis() + INVALIDATION_TIMEOUT);
- this.token = generateToken();
- return this.token;
+ invalidatedTokens.put(token, System.currentTimeMillis() + INVALIDATION_TIMEOUT);
+ token = generateToken();
}
private void timeoutCheck() {
- long now = System.currentTimeMillis();
+ final long now = System.currentTimeMillis();
invalidatedTokens.entrySet().removeIf(entry -> entry.getValue() < now);
}
@Override
public String toString() {
- return this.series + ":" + this.token;
+ return series + ":" + token;
}
}
}
diff --git a/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLoginFunctions.java b/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLoginFunctions.java
index 7a2f7aad620..ec988467c51 100644
--- a/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLoginFunctions.java
+++ b/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLoginFunctions.java
@@ -27,17 +27,36 @@
import org.exist.security.SecurityManager;
import org.exist.security.Subject;
import org.exist.storage.BrokerPool;
-import org.exist.xquery.*;
-import org.exist.xquery.value.*;
+import org.exist.xquery.AnalyzeContextInfo;
+import org.exist.xquery.Cardinality;
+import org.exist.xquery.ErrorCodes;
+import org.exist.xquery.Function;
+import org.exist.xquery.FunctionSignature;
+import org.exist.xquery.UserSwitchingBasicFunction;
+import org.exist.xquery.XPathException;
+import org.exist.xquery.XQueryContext;
+import org.exist.xquery.value.DurationValue;
+import org.exist.xquery.value.FunctionParameterSequenceType;
+import org.exist.xquery.value.FunctionReference;
+import org.exist.xquery.value.FunctionReturnSequenceType;
+import org.exist.xquery.value.Sequence;
+import org.exist.xquery.value.SequenceType;
+import org.exist.xquery.value.StringValue;
+import org.exist.xquery.value.Type;
+
+import javax.annotation.Nullable;
+
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.Map;
/**
* Functions to access the persistent login module.
*/
public class PersistentLoginFunctions extends UserSwitchingBasicFunction {
-
- public final static FunctionSignature signatures[] = {
+ public final static FunctionSignature[] signatures = {
new FunctionSignature(
- new QName("register", PersistentLoginModule.NAMESPACE, PersistentLoginModule.PREFIX),
+ PersistentLoginFn.REGISTER.getQName(),
"Try to log in the user and create a one-time login token. The token can be stored to a cookie and used to log in " +
"(via the login function) as the same user without " +
"providing credentials. However, for security reasons the token will be valid only for " +
@@ -55,7 +74,7 @@ public class PersistentLoginFunctions extends UserSwitchingBasicFunction {
new FunctionReturnSequenceType(Type.ITEM, Cardinality.ZERO_OR_MORE, "result of the callback function or the empty sequence")
),
new FunctionSignature(
- new QName("login", PersistentLoginModule.NAMESPACE, PersistentLoginModule.PREFIX),
+ PersistentLoginFn.LOGIN.getQName(),
"Try to log in the user based on the supplied token. If the login succeeds, the provided callback function " +
"is called with 4 arguments: $token as xs:string, $user as xs:string, $password as xs:string, $timeToLive as duration. " +
"$token will be a new token which can be used for the next request. The old token is deleted.",
@@ -67,7 +86,7 @@ public class PersistentLoginFunctions extends UserSwitchingBasicFunction {
new FunctionReturnSequenceType(Type.ITEM, Cardinality.ZERO_OR_MORE, "result of the callback function or the empty sequence")
),
new FunctionSignature(
- new QName("invalidate", PersistentLoginModule.NAMESPACE, PersistentLoginModule.PREFIX),
+ PersistentLoginFn.INVALIDATE.getQName(),
"Invalidate the supplied one-time token, so it can no longer be used to log in.",
new SequenceType[]{
new FunctionParameterSequenceType("token", Type.STRING, Cardinality.EXACTLY_ONE, "a valid one-time token")
@@ -75,87 +94,71 @@ public class PersistentLoginFunctions extends UserSwitchingBasicFunction {
new FunctionReturnSequenceType(Type.EMPTY, Cardinality.EXACTLY_ONE, "empty sequence")
)
};
-
private AnalyzeContextInfo cachedContextInfo;
public PersistentLoginFunctions(final XQueryContext context, final FunctionSignature signature) {
super(context, signature);
}
+ private static Sequence invalidate(Sequence[] args) throws XPathException {
+ PersistentLogin.getInstance().invalidate(args[0].getStringValue());
+ return Sequence.EMPTY_SEQUENCE;
+ }
+
@Override
public void analyze(final AnalyzeContextInfo contextInfo) throws XPathException {
super.analyze(contextInfo);
- this.cachedContextInfo = new AnalyzeContextInfo(contextInfo);
+ cachedContextInfo = new AnalyzeContextInfo(contextInfo);
}
@Override
public Sequence eval(final Sequence[] args, final Sequence contextSequence) throws XPathException {
- if (isCalledAs("register")) {
- final String user = args[0].getStringValue();
- final String pass;
- if (!args[1].isEmpty()) {
- pass = args[1].getStringValue();
- } else {
- pass = null;
- }
- final DurationValue timeToLive = (DurationValue) args[2].itemAt(0);
- final FunctionReference callback;
- if (!args[3].isEmpty()) {
- callback = (FunctionReference) args[3].itemAt(0);
- } else {
- callback = null;
- }
- try {
- return register(user, pass, timeToLive, callback);
- } finally {
- if (callback != null) {
- callback.close();
- }
- }
- } else if (isCalledAs("login")) {
- final String token = args[0].getStringValue();
- final FunctionReference callback;
- if (!args[1].isEmpty()) {
- callback = (FunctionReference) args[1].itemAt(0);
- } else {
- callback = null;
- }
- try {
- return authenticate(token, callback);
- } finally {
- if (callback != null) {
- callback.close();
- }
- }
- } else {
- PersistentLogin.getInstance().invalidate(args[0].getStringValue());
- return Sequence.EMPTY_SEQUENCE;
+ switch (PersistentLoginFn.get(this)) {
+ case REGISTER:
+ return register(args);
+ case LOGIN:
+ return login(args);
+ case INVALIDATE:
+ return invalidate(args);
+ default:
+ throw new XPathException(this, ErrorCodes.ERROR, "Unknown function: " + getName());
}
}
- private Sequence register(final String user, final String pass, final DurationValue timeToLive, final FunctionReference callback) throws XPathException {
- if (login(user, pass)) {
- final PersistentLogin.LoginDetails details = PersistentLogin.getInstance().register(user, pass, timeToLive);
- return callback(callback, null, details);
+ private Sequence register(Sequence[] args) throws XPathException {
+ final String user = args[0].getStringValue();
+
+ final String pass;
+ if (args[1].isEmpty()) {
+ pass = null;
+ } else {
+ pass = args[1].getStringValue();
}
- return Sequence.EMPTY_SEQUENCE;
- }
- private Sequence authenticate(final String token, final FunctionReference callback) throws XPathException {
- final PersistentLogin.LoginDetails data = PersistentLogin.getInstance().lookup(token);
+ final DurationValue timeToLive = (DurationValue) args[2].itemAt(0);
- if (data == null) {
- return Sequence.EMPTY_SEQUENCE;
+ try (FunctionReference callback = getCallBack(args[3])) {
+ if (unauthenticated(user, pass)) {
+ return Sequence.EMPTY_SEQUENCE;
+ }
+ final PersistentLogin.LoginDetails details = PersistentLogin.getInstance().register(user, pass, timeToLive);
+ return call(callback, null, details);
}
+ }
- if (login(data.getUser(), data.getPassword())) {
- return callback(callback, token, data);
- }
+ private Sequence login(Sequence[] args) throws XPathException {
+ final String token = args[0].getStringValue();
+ try (FunctionReference callback = getCallBack(args[1])) {
+ final PersistentLogin.LoginDetails data = PersistentLogin.getInstance().lookup(token);
- return Sequence.EMPTY_SEQUENCE;
+ if (data == null || unauthenticated(data.getUser(), data.getPassword())) {
+ return Sequence.EMPTY_SEQUENCE;
+ }
+ return call(callback, token, data);
+ }
}
- private boolean login(final String user, final String pass) throws XPathException {
+ private boolean unauthenticated(final String user, final String pass) {
try {
final SecurityManager sm = BrokerPool.getInstance().getSecurityManager();
final Subject subject = sm.authenticate(user, pass);
@@ -163,13 +166,14 @@ private boolean login(final String user, final String pass) throws XPathExceptio
//switch the user of the current broker
switchUser(subject);
- return true;
- } catch (final AuthenticationException | EXistException e) {
return false;
+ } catch (final AuthenticationException | EXistException e) {
+ return true;
}
}
- private Sequence callback(final FunctionReference func, final String oldToken, final PersistentLogin.LoginDetails details) throws XPathException {
+ private Sequence call(@Nullable final FunctionReference func, final String oldToken, final PersistentLogin.LoginDetails details) throws XPathException {
+ if (func == null) return Sequence.EMPTY_SEQUENCE;
final Sequence[] args = new Sequence[4];
final String newToken = details.toString();
@@ -185,4 +189,39 @@ private Sequence callback(final FunctionReference func, final String oldToken, f
func.analyze(cachedContextInfo);
return func.evalFunction(null, null, args);
}
+
+ private @Nullable FunctionReference getCallBack(final Sequence arg) {
+ if (arg.isEmpty()) {
+ return null;
+ }
+ return (FunctionReference) arg.itemAt(0);
+ }
+
+ private enum PersistentLoginFn {
+ REGISTER("register"),
+ LOGIN("login"),
+ INVALIDATE("invalidate");
+
+ final static Map lookup = new HashMap<>();
+
+ static {
+ for (PersistentLoginFn persistentLoginFn : EnumSet.allOf(PersistentLoginFn.class)) {
+ lookup.put(persistentLoginFn.getQName(), persistentLoginFn);
+ }
+ }
+
+ private final QName qname;
+
+ PersistentLoginFn(String name) {
+ qname = new QName(name, PersistentLoginModule.NAMESPACE, PersistentLoginModule.PREFIX);
+ }
+
+ static PersistentLoginFn get(Function f) {
+ return lookup.get(f.getName());
+ }
+
+ public QName getQName() {
+ return qname;
+ }
+ }
}