Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 160 additions & 24 deletions java/src/org/openqa/selenium/grid/sessionmap/local/LocalSessionMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@
import static org.openqa.selenium.remote.RemoteTags.SESSION_ID;
import static org.openqa.selenium.remote.RemoteTags.SESSION_ID_EVENT;

import java.util.List;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.openqa.selenium.NoSuchSessionException;
import org.openqa.selenium.events.Event;
import org.openqa.selenium.events.EventBus;
import org.openqa.selenium.grid.config.Config;
import org.openqa.selenium.grid.data.NodeRemovedEvent;
Expand All @@ -48,7 +53,7 @@ public class LocalSessionMap extends SessionMap {
private static final Logger LOG = Logger.getLogger(LocalSessionMap.class.getName());

private final EventBus bus;
private final ConcurrentMap<SessionId, Session> knownSessions = new ConcurrentHashMap<>();
private final IndexedSessionMap knownSessions = new IndexedSessionMap();

public LocalSessionMap(Tracer tracer, EventBus bus) {
super(tracer);
Expand All @@ -59,23 +64,14 @@ public LocalSessionMap(Tracer tracer, EventBus bus) {

bus.addListener(
NodeRemovedEvent.listener(
nodeStatus ->
nodeStatus.getSlots().stream()
.filter(slot -> slot.getSession() != null)
.map(slot -> slot.getSession().getId())
.forEach(this::remove)));
nodeStatus -> {
batchRemoveByUri(nodeStatus.getExternalUri(), NodeRemovedEvent.class);
}));

bus.addListener(
NodeRestartedEvent.listener(
previousNodeStatus -> {
List<SessionId> toRemove =
knownSessions.entrySet().stream()
.filter(
(e) -> e.getValue().getUri().equals(previousNodeStatus.getExternalUri()))
.map(Map.Entry::getKey)
.collect(Collectors.toList());

toRemove.forEach(this::remove);
batchRemoveByUri(previousNodeStatus.getExternalUri(), NodeRestartedEvent.class);
}));
}

Expand All @@ -95,17 +91,23 @@ public boolean isReady() {
public boolean add(Session session) {
Require.nonNull("Session", session);

SessionId id = session.getId();
knownSessions.put(id, session);

try (Span span = tracer.getCurrentContext().createSpan("local_sessionmap.add")) {
AttributeMap attributeMap = tracer.createAttributeMap();
attributeMap.put(AttributeKey.LOGGER_CLASS.getKey(), getClass().getName());
SessionId id = session.getId();
SESSION_ID.accept(span, id);
SESSION_ID_EVENT.accept(attributeMap, id);
knownSessions.put(session.getId(), session);
span.addEvent("Added session into local session map", attributeMap);

return true;
String sessionAddedMessage =
String.format(
"Added session to local Session Map, Id: %s, Node: %s", id, session.getUri());
span.addEvent(sessionAddedMessage, attributeMap);
LOG.info(sessionAddedMessage);
}

return true;
}

@Override
Expand All @@ -116,23 +118,157 @@ public Session get(SessionId id) {
if (session == null) {
throw new NoSuchSessionException("Unable to find session with ID: " + id);
}

return session;
}

@Override
public void remove(SessionId id) {
Require.nonNull("Session ID", id);

Session removedSession = knownSessions.remove(id);

try (Span span = tracer.getCurrentContext().createSpan("local_sessionmap.remove")) {
AttributeMap attributeMap = tracer.createAttributeMap();
attributeMap.put(AttributeKey.LOGGER_CLASS.getKey(), getClass().getName());
SESSION_ID.accept(span, id);
SESSION_ID_EVENT.accept(attributeMap, id);
knownSessions.remove(id);
String sessionDeletedMessage = "Deleted session from local Session Map";

String sessionDeletedMessage =
String.format(
"Deleted session from local Session Map, Id: %s, Node: %s",
id,
removedSession != null ? String.valueOf(removedSession.getUri()) : "unidentified");
span.addEvent(sessionDeletedMessage, attributeMap);
LOG.info(String.format("%s, Id: %s", sessionDeletedMessage, id));
LOG.info(sessionDeletedMessage);
}
}

private void batchRemoveByUri(URI externalUri, Class<? extends Event> eventClass) {
Set<SessionId> sessionsToRemove = knownSessions.getSessionsByUri(externalUri);

if (sessionsToRemove.isEmpty()) {
return; // Early return for empty operations - no tracing overhead
}

knownSessions.batchRemove(sessionsToRemove);

try (Span span = tracer.getCurrentContext().createSpan("local_sessionmap.batch_remove")) {
AttributeMap attributeMap = tracer.createAttributeMap();
attributeMap.put(AttributeKey.LOGGER_CLASS.getKey(), getClass().getName());
attributeMap.put("event.class", eventClass.getName());
attributeMap.put("node.uri", externalUri.toString());
attributeMap.put("sessions.count", sessionsToRemove.size());

String batchRemoveMessage =
String.format(
"Batch removed %d sessions from local Session Map for Node %s (triggered by %s)",
sessionsToRemove.size(), externalUri, eventClass.getSimpleName());
span.addEvent(batchRemoveMessage, attributeMap);
LOG.info(batchRemoveMessage);
}
}

private static class IndexedSessionMap {
private final ConcurrentMap<SessionId, Session> sessions = new ConcurrentHashMap<>();
private final ConcurrentMap<URI, Set<SessionId>> sessionsByUri = new ConcurrentHashMap<>();
private final Object coordinationLock = new Object();

public Session get(SessionId id) {
return sessions.get(id);
}

public Session put(SessionId id, Session session) {
synchronized (coordinationLock) {
Session previous = sessions.put(id, session);

if (previous != null && previous.getUri() != null) {
cleanupUriIndex(previous.getUri(), id);
}

URI sessionUri = session.getUri();
if (sessionUri != null) {
sessionsByUri.computeIfAbsent(sessionUri, k -> ConcurrentHashMap.newKeySet()).add(id);
}

return previous;
}
}

public Session remove(SessionId id) {
synchronized (coordinationLock) {
Session removed = sessions.remove(id);

if (removed != null && removed.getUri() != null) {
cleanupUriIndex(removed.getUri(), id);
}

return removed;
}
}

public void batchRemove(Set<SessionId> sessionIds) {
synchronized (coordinationLock) {
Map<URI, Set<SessionId>> uriToSessionIds = new HashMap<>();

// Single loop: remove sessions and collect URI mappings in one pass
for (SessionId id : sessionIds) {
Session session = sessions.remove(id);
if (session != null && session.getUri() != null) {
uriToSessionIds.computeIfAbsent(session.getUri(), k -> new HashSet<>()).add(id);
}
}

// Clean up URI index for all affected URIs
for (Map.Entry<URI, Set<SessionId>> entry : uriToSessionIds.entrySet()) {
cleanupUriIndex(entry.getKey(), entry.getValue());
}
}
}

private void cleanupUriIndex(URI uri, SessionId sessionId) {
sessionsByUri.computeIfPresent(
uri,
(key, sessionIds) -> {
sessionIds.remove(sessionId);
return sessionIds.isEmpty() ? null : sessionIds;
});
}

private void cleanupUriIndex(URI uri, Set<SessionId> sessionIdsToRemove) {
sessionsByUri.computeIfPresent(
uri,
(key, sessionIds) -> {
sessionIds.removeAll(sessionIdsToRemove);
return sessionIds.isEmpty() ? null : sessionIds;
});
}

public Set<SessionId> getSessionsByUri(URI uri) {
Set<SessionId> result = sessionsByUri.get(uri);
return (result != null && !result.isEmpty()) ? result : Set.of();
}

public Set<Map.Entry<SessionId, Session>> entrySet() {
return Collections.unmodifiableSet(sessions.entrySet());
}

public Collection<Session> values() {
return Collections.unmodifiableCollection(sessions.values());
}

public int size() {
return sessions.size();
}

public boolean isEmpty() {
return sessions.isEmpty();
}

public void clear() {
synchronized (coordinationLock) {
sessions.clear();
sessionsByUri.clear();
}
}
}
}
20 changes: 20 additions & 0 deletions java/test/org/openqa/selenium/grid/sessionmap/local/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
load("@rules_jvm_external//:defs.bzl", "artifact")
load("//java:defs.bzl", "JUNIT5_DEPS", "java_test_suite")

java_test_suite(
name = "SmallTests",
size = "medium",
srcs = glob(["*.java"]),
deps = [
"//java/src/org/openqa/selenium:core",
"//java/src/org/openqa/selenium/events",
"//java/src/org/openqa/selenium/events/local",
"//java/src/org/openqa/selenium/grid/data",
"//java/src/org/openqa/selenium/grid/sessionmap",
"//java/src/org/openqa/selenium/grid/sessionmap/local",
"//java/src/org/openqa/selenium/remote",
"//java/test/org/openqa/selenium/remote/tracing:tracing-support",
artifact("org.assertj:assertj-core"),
artifact("org.junit.jupiter:junit-jupiter-api"),
] + JUNIT5_DEPS,
)
Loading
Loading