diff --git a/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLogin.java b/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLogin.java index 64c24b74689..781cf1d689d 100644 --- a/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLogin.java +++ b/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLogin.java @@ -31,15 +31,16 @@ import java.security.SecureRandom; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; /** * A persistent login feature ("remember me") similar to the implementation in Spring Security, * which is based on Improved Persistent Login Cookie * Best Practice . - * + *

* The one-time tokens generated by this class are purely random and do not contain a user name or other information. For security reasons, * tokens and user information are not stored anywhere, so if the database is shut down, registered tokens will be gone. - * + *

* The one-time token approach has the negative effect that requests need to be made in sequence, which is sometimes difficult if an app uses * concurrent AJAX requests. Unfortunately, this is the price we have to pay for a sufficiently secure protection against * cookie stealing attacks. @@ -60,11 +61,11 @@ public static PersistentLogin getInstance() { public final static int DEFAULT_TOKEN_LENGTH = 16; - public final static int INVALIDATION_TIMEOUT = 20000; + public final static int INVALIDATION_TIMEOUT = 10000; - private Map seriesMap = Collections.synchronizedMap(new HashMap<>()); + private final Map seriesMap = new ConcurrentHashMap<>(); - private SecureRandom random; + private final SecureRandom random; public PersistentLogin() { random = new SecureRandom(); @@ -73,10 +74,10 @@ public PersistentLogin() { /** * Register the user and generate a first login token which will be valid for the next * call to {@link #lookup(String)}. - * + *

* The generated token will have the format base64(series-hash):base64(token-hash). * - * @param user the user name + * @param user the username * @param password the password * @param timeToLive timeout of the token * @return a first login token @@ -154,17 +155,18 @@ private String generateToken() { public class LoginDetails { - private String userName; - private String password; private String token; - private String series; - private long expires; - private DurationValue timeToLive; + + private final String userName; + private final String password; + private final String series; + private final long expires; + private final DurationValue timeToLive; // disable sequential token checking by default - private boolean seqBehavior = false; + private final boolean seqBehavior = false; - private Map invalidatedTokens = new HashMap<>(); + private final Map invalidatedTokens = new HashMap<>(); public LoginDetails(String user, String password, DurationValue timeToLive, long expires) { this.userName = user; @@ -176,19 +178,19 @@ public LoginDetails(String user, String password, DurationValue timeToLive, long } public String getToken() { - return this.token; + return token; } public String getSeries() { - return this.series; + return series; } public String getUser() { - return this.userName; + return userName; } public String getPassword() { - return this.password; + return password; } public DurationValue getTimeToLive() { @@ -197,13 +199,15 @@ public DurationValue getTimeToLive() { public boolean checkAndUpdateToken(String token) { if (this.token.equals(token)) { + timeoutCheck(); update(); return true; } // check map of invalidating tokens - Long timeout = invalidatedTokens.get(token); - if (timeout == null) + final Long timeout = invalidatedTokens.get(token); + if (timeout == null) { return false; + } // timed out: remove if (System.currentTimeMillis() > timeout) { invalidatedTokens.remove(token); @@ -213,23 +217,21 @@ public boolean checkAndUpdateToken(String token) { return true; } - public String update() { - timeoutCheck(); - // leave a small time window until previous token is deleted + public void update() { + // leave a small time-window until previous token is deleted // to allow for concurrent requests - invalidatedTokens.put(this.token, System.currentTimeMillis() + INVALIDATION_TIMEOUT); - this.token = generateToken(); - return this.token; + invalidatedTokens.put(token, System.currentTimeMillis() + INVALIDATION_TIMEOUT); + token = generateToken(); } private void timeoutCheck() { - long now = System.currentTimeMillis(); + final long now = System.currentTimeMillis(); invalidatedTokens.entrySet().removeIf(entry -> entry.getValue() < now); } @Override public String toString() { - return this.series + ":" + this.token; + return series + ":" + token; } } } diff --git a/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLoginFunctions.java b/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLoginFunctions.java index 7a2f7aad620..ec988467c51 100644 --- a/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLoginFunctions.java +++ b/extensions/modules/persistentlogin/src/main/java/org/exist/xquery/modules/persistentlogin/PersistentLoginFunctions.java @@ -27,17 +27,36 @@ import org.exist.security.SecurityManager; import org.exist.security.Subject; import org.exist.storage.BrokerPool; -import org.exist.xquery.*; -import org.exist.xquery.value.*; +import org.exist.xquery.AnalyzeContextInfo; +import org.exist.xquery.Cardinality; +import org.exist.xquery.ErrorCodes; +import org.exist.xquery.Function; +import org.exist.xquery.FunctionSignature; +import org.exist.xquery.UserSwitchingBasicFunction; +import org.exist.xquery.XPathException; +import org.exist.xquery.XQueryContext; +import org.exist.xquery.value.DurationValue; +import org.exist.xquery.value.FunctionParameterSequenceType; +import org.exist.xquery.value.FunctionReference; +import org.exist.xquery.value.FunctionReturnSequenceType; +import org.exist.xquery.value.Sequence; +import org.exist.xquery.value.SequenceType; +import org.exist.xquery.value.StringValue; +import org.exist.xquery.value.Type; + +import javax.annotation.Nullable; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; /** * Functions to access the persistent login module. */ public class PersistentLoginFunctions extends UserSwitchingBasicFunction { - - public final static FunctionSignature signatures[] = { + public final static FunctionSignature[] signatures = { new FunctionSignature( - new QName("register", PersistentLoginModule.NAMESPACE, PersistentLoginModule.PREFIX), + PersistentLoginFn.REGISTER.getQName(), "Try to log in the user and create a one-time login token. The token can be stored to a cookie and used to log in " + "(via the login function) as the same user without " + "providing credentials. However, for security reasons the token will be valid only for " + @@ -55,7 +74,7 @@ public class PersistentLoginFunctions extends UserSwitchingBasicFunction { new FunctionReturnSequenceType(Type.ITEM, Cardinality.ZERO_OR_MORE, "result of the callback function or the empty sequence") ), new FunctionSignature( - new QName("login", PersistentLoginModule.NAMESPACE, PersistentLoginModule.PREFIX), + PersistentLoginFn.LOGIN.getQName(), "Try to log in the user based on the supplied token. If the login succeeds, the provided callback function " + "is called with 4 arguments: $token as xs:string, $user as xs:string, $password as xs:string, $timeToLive as duration. " + "$token will be a new token which can be used for the next request. The old token is deleted.", @@ -67,7 +86,7 @@ public class PersistentLoginFunctions extends UserSwitchingBasicFunction { new FunctionReturnSequenceType(Type.ITEM, Cardinality.ZERO_OR_MORE, "result of the callback function or the empty sequence") ), new FunctionSignature( - new QName("invalidate", PersistentLoginModule.NAMESPACE, PersistentLoginModule.PREFIX), + PersistentLoginFn.INVALIDATE.getQName(), "Invalidate the supplied one-time token, so it can no longer be used to log in.", new SequenceType[]{ new FunctionParameterSequenceType("token", Type.STRING, Cardinality.EXACTLY_ONE, "a valid one-time token") @@ -75,87 +94,71 @@ public class PersistentLoginFunctions extends UserSwitchingBasicFunction { new FunctionReturnSequenceType(Type.EMPTY, Cardinality.EXACTLY_ONE, "empty sequence") ) }; - private AnalyzeContextInfo cachedContextInfo; public PersistentLoginFunctions(final XQueryContext context, final FunctionSignature signature) { super(context, signature); } + private static Sequence invalidate(Sequence[] args) throws XPathException { + PersistentLogin.getInstance().invalidate(args[0].getStringValue()); + return Sequence.EMPTY_SEQUENCE; + } + @Override public void analyze(final AnalyzeContextInfo contextInfo) throws XPathException { super.analyze(contextInfo); - this.cachedContextInfo = new AnalyzeContextInfo(contextInfo); + cachedContextInfo = new AnalyzeContextInfo(contextInfo); } @Override public Sequence eval(final Sequence[] args, final Sequence contextSequence) throws XPathException { - if (isCalledAs("register")) { - final String user = args[0].getStringValue(); - final String pass; - if (!args[1].isEmpty()) { - pass = args[1].getStringValue(); - } else { - pass = null; - } - final DurationValue timeToLive = (DurationValue) args[2].itemAt(0); - final FunctionReference callback; - if (!args[3].isEmpty()) { - callback = (FunctionReference) args[3].itemAt(0); - } else { - callback = null; - } - try { - return register(user, pass, timeToLive, callback); - } finally { - if (callback != null) { - callback.close(); - } - } - } else if (isCalledAs("login")) { - final String token = args[0].getStringValue(); - final FunctionReference callback; - if (!args[1].isEmpty()) { - callback = (FunctionReference) args[1].itemAt(0); - } else { - callback = null; - } - try { - return authenticate(token, callback); - } finally { - if (callback != null) { - callback.close(); - } - } - } else { - PersistentLogin.getInstance().invalidate(args[0].getStringValue()); - return Sequence.EMPTY_SEQUENCE; + switch (PersistentLoginFn.get(this)) { + case REGISTER: + return register(args); + case LOGIN: + return login(args); + case INVALIDATE: + return invalidate(args); + default: + throw new XPathException(this, ErrorCodes.ERROR, "Unknown function: " + getName()); } } - private Sequence register(final String user, final String pass, final DurationValue timeToLive, final FunctionReference callback) throws XPathException { - if (login(user, pass)) { - final PersistentLogin.LoginDetails details = PersistentLogin.getInstance().register(user, pass, timeToLive); - return callback(callback, null, details); + private Sequence register(Sequence[] args) throws XPathException { + final String user = args[0].getStringValue(); + + final String pass; + if (args[1].isEmpty()) { + pass = null; + } else { + pass = args[1].getStringValue(); } - return Sequence.EMPTY_SEQUENCE; - } - private Sequence authenticate(final String token, final FunctionReference callback) throws XPathException { - final PersistentLogin.LoginDetails data = PersistentLogin.getInstance().lookup(token); + final DurationValue timeToLive = (DurationValue) args[2].itemAt(0); - if (data == null) { - return Sequence.EMPTY_SEQUENCE; + try (FunctionReference callback = getCallBack(args[3])) { + if (unauthenticated(user, pass)) { + return Sequence.EMPTY_SEQUENCE; + } + final PersistentLogin.LoginDetails details = PersistentLogin.getInstance().register(user, pass, timeToLive); + return call(callback, null, details); } + } - if (login(data.getUser(), data.getPassword())) { - return callback(callback, token, data); - } + private Sequence login(Sequence[] args) throws XPathException { + final String token = args[0].getStringValue(); + try (FunctionReference callback = getCallBack(args[1])) { + final PersistentLogin.LoginDetails data = PersistentLogin.getInstance().lookup(token); - return Sequence.EMPTY_SEQUENCE; + if (data == null || unauthenticated(data.getUser(), data.getPassword())) { + return Sequence.EMPTY_SEQUENCE; + } + return call(callback, token, data); + } } - private boolean login(final String user, final String pass) throws XPathException { + private boolean unauthenticated(final String user, final String pass) { try { final SecurityManager sm = BrokerPool.getInstance().getSecurityManager(); final Subject subject = sm.authenticate(user, pass); @@ -163,13 +166,14 @@ private boolean login(final String user, final String pass) throws XPathExceptio //switch the user of the current broker switchUser(subject); - return true; - } catch (final AuthenticationException | EXistException e) { return false; + } catch (final AuthenticationException | EXistException e) { + return true; } } - private Sequence callback(final FunctionReference func, final String oldToken, final PersistentLogin.LoginDetails details) throws XPathException { + private Sequence call(@Nullable final FunctionReference func, final String oldToken, final PersistentLogin.LoginDetails details) throws XPathException { + if (func == null) return Sequence.EMPTY_SEQUENCE; final Sequence[] args = new Sequence[4]; final String newToken = details.toString(); @@ -185,4 +189,39 @@ private Sequence callback(final FunctionReference func, final String oldToken, f func.analyze(cachedContextInfo); return func.evalFunction(null, null, args); } + + private @Nullable FunctionReference getCallBack(final Sequence arg) { + if (arg.isEmpty()) { + return null; + } + return (FunctionReference) arg.itemAt(0); + } + + private enum PersistentLoginFn { + REGISTER("register"), + LOGIN("login"), + INVALIDATE("invalidate"); + + final static Map lookup = new HashMap<>(); + + static { + for (PersistentLoginFn persistentLoginFn : EnumSet.allOf(PersistentLoginFn.class)) { + lookup.put(persistentLoginFn.getQName(), persistentLoginFn); + } + } + + private final QName qname; + + PersistentLoginFn(String name) { + qname = new QName(name, PersistentLoginModule.NAMESPACE, PersistentLoginModule.PREFIX); + } + + static PersistentLoginFn get(Function f) { + return lookup.get(f.getName()); + } + + public QName getQName() { + return qname; + } + } }