5959import  java .util .concurrent .Executors ;
6060import  java .util .concurrent .ScheduledExecutorService ;
6161import  java .util .concurrent .TimeUnit ;
62- import  java .util .concurrent .atomic .AtomicBoolean ;
6362import  java .util .concurrent .atomic .AtomicInteger ;
6463import  java .util .concurrent .atomic .AtomicLong ;
64+ import  java .util .concurrent .locks .Lock ;
65+ import  java .util .concurrent .locks .ReadWriteLock ;
66+ import  java .util .concurrent .locks .ReentrantReadWriteLock ;
6567import  java .util .logging .Level ;
6668import  java .util .logging .Logger ;
6769import  java .util .stream .Collectors ;
@@ -133,7 +135,7 @@ public class LocalNode extends Node implements Closeable {
133135  private  final  int  connectionLimitPerSession ;
134136
135137  private  final  boolean  bidiEnabled ;
136-   private  final  AtomicBoolean  drainAfterSessions  =  new   AtomicBoolean () ;
138+   private  final  boolean  drainAfterSessions ;
137139  private  final  List <SessionSlot > factories ;
138140  private  final  Cache <SessionId , SessionSlot > currentSessions ;
139141  private  final  Cache <SessionId , TemporaryFilesystem > uploadsTempFileSystem ;
@@ -142,6 +144,7 @@ public class LocalNode extends Node implements Closeable {
142144  private  final  AtomicInteger  pendingSessions  = new  AtomicInteger ();
143145  private  final  AtomicInteger  sessionCount  = new  AtomicInteger ();
144146  private  final  Runnable  shutdown ;
147+   private  final  ReadWriteLock  drainLock  = new  ReentrantReadWriteLock ();
145148
146149  protected  LocalNode (
147150      Tracer  tracer ,
@@ -177,7 +180,7 @@ protected LocalNode(
177180    this .factories  = ImmutableList .copyOf (factories );
178181    Require .nonNull ("Registration secret" , registrationSecret );
179182    this .configuredSessionCount  = drainAfterSessionCount ;
180-     this .drainAfterSessions . set ( this .configuredSessionCount  > 0 ) ;
183+     this .drainAfterSessions  =  this .configuredSessionCount  > 0 ;
181184    this .sessionCount .set (drainAfterSessionCount );
182185    this .cdpEnabled  = cdpEnabled ;
183186    this .bidiEnabled  = bidiEnabled ;
@@ -443,6 +446,9 @@ public Either<WebDriverException, CreateSessionResponse> newSession(
443446      CreateSessionRequest  sessionRequest ) {
444447    Require .nonNull ("Session request" , sessionRequest );
445448
449+     Lock  lock  = drainLock .readLock ();
450+     lock .lock ();
451+ 
446452    try  (Span  span  = tracer .getCurrentContext ().createSpan ("node.new_session" )) {
447453      AttributeMap  attributeMap  = tracer .createAttributeMap ();
448454      attributeMap .put (AttributeKey .LOGGER_CLASS .getKey (), getClass ().getName ());
@@ -455,13 +461,14 @@ public Either<WebDriverException, CreateSessionResponse> newSession(
455461      span .setAttribute ("current.session.count" , currentSessionCount );
456462      attributeMap .put ("current.session.count" , currentSessionCount );
457463
458-       if  (getCurrentSessionCount ()  >= maxSessionCount ) {
464+       if  (currentSessionCount  >= maxSessionCount ) {
459465        span .setAttribute (AttributeKey .ERROR .getKey (), true );
460466        span .setStatus (Status .RESOURCE_EXHAUSTED );
461467        attributeMap .put ("max.session.count" , maxSessionCount );
462468        span .addEvent ("Max session count reached" , attributeMap );
463469        return  Either .left (new  RetrySessionRequestException ("Max session count reached." ));
464470      }
471+ 
465472      if  (isDraining ()) {
466473        span .setStatus (
467474            Status .UNAVAILABLE .withDescription (
@@ -492,6 +499,15 @@ public Either<WebDriverException, CreateSessionResponse> newSession(
492499            new  RetrySessionRequestException ("No slot matched the requested capabilities." ));
493500      }
494501
502+       if  (!decrementSessionCount ()) {
503+         slotToUse .release ();
504+         span .setAttribute (AttributeKey .ERROR .getKey (), true );
505+         span .setStatus (Status .RESOURCE_EXHAUSTED );
506+         attributeMap .put ("drain.after.session.count" , configuredSessionCount );
507+         span .addEvent ("Drain after session count reached" , attributeMap );
508+         return  Either .left (new  RetrySessionRequestException ("Drain after session count reached." ));
509+       }
510+ 
495511      UUID  uuidForSessionDownloads  = UUID .randomUUID ();
496512      Capabilities  desiredCapabilities  = sessionRequest .getDesiredCapabilities ();
497513      if  (managedDownloadsRequested (desiredCapabilities )) {
@@ -548,6 +564,7 @@ public Either<WebDriverException, CreateSessionResponse> newSession(
548564        return  Either .left (possibleSession .left ());
549565      }
550566    } finally  {
567+       lock .unlock ();
551568      checkSessionCount ();
552569    }
553570  }
@@ -1020,20 +1037,40 @@ public void drain() {
10201037  }
10211038
10221039  private  void  checkSessionCount () {
1023-     if  (this .drainAfterSessions .get ()) {
1040+     if  (this .drainAfterSessions ) {
1041+       Lock  lock  = drainLock .writeLock ();
1042+       if  (!lock .tryLock ()) {
1043+         // in case we can't get a write lock another thread does hold a read lock and will call 
1044+         // checkSessionCount as soon as he releases the read lock. So we do not need to wait here 
1045+         // for the other session to start and release the lock, just continue and let the other 
1046+         // session start to drain the node. 
1047+         return ;
1048+       }
1049+       try  {
1050+         int  remainingSessions  = this .sessionCount .get ();
1051+         if  (remainingSessions  <= 0 ) {
1052+           LOG .info (
1053+               String .format (
1054+                   "Draining Node, configured sessions value (%s) has been reached." ,
1055+                   this .configuredSessionCount ));
1056+           drain ();
1057+         }
1058+       } finally  {
1059+         lock .unlock ();
1060+       }
1061+     }
1062+   }
1063+ 
1064+   private  boolean  decrementSessionCount () {
1065+     if  (this .drainAfterSessions ) {
10241066      int  remainingSessions  = this .sessionCount .decrementAndGet ();
10251067      LOG .log (
10261068          Debug .getDebugLogLevel (),
10271069          "{0} remaining sessions before draining Node" ,
10281070          remainingSessions );
1029-       if  (remainingSessions  <= 0 ) {
1030-         LOG .info (
1031-             String .format (
1032-                 "Draining Node, configured sessions value (%s) has been reached." ,
1033-                 this .configuredSessionCount ));
1034-         drain ();
1035-       }
1071+       return  remainingSessions  >= 0 ;
10361072    }
1073+     return  true ;
10371074  }
10381075
10391076  private  Map <String , Object > toJson () {
0 commit comments