@@ -10,15 +10,21 @@ use std::{
1010
1111use super :: {
1212 server_manager:: ServerManager ,
13- session:: { Session , SessionProcessList , SessionStatActivity , SessionState } ,
13+ session:: { Session , SessionState } ,
1414} ;
1515use crate :: compile:: DatabaseProtocol ;
1616
17+ #[ derive( Debug ) ]
18+ struct SessionManagerInner {
19+ sessions : HashMap < u32 , Arc < Session > > ,
20+ uid_to_session : HashMap < String , Arc < Session > > ,
21+ }
22+
1723#[ derive( Debug ) ]
1824pub struct SessionManager {
1925 // Sessions
2026 last_id : AtomicU32 ,
21- sessions : RWLockAsync < HashMap < u32 , Arc < Session > > > ,
27+ sessions : RWLockAsync < SessionManagerInner > ,
2228 pub temp_table_size : AtomicUsize ,
2329 // Backref
2430 pub server : Arc < ServerManager > ,
@@ -30,7 +36,10 @@ impl SessionManager {
3036 pub fn new ( server : Arc < ServerManager > ) -> Self {
3137 Self {
3238 last_id : AtomicU32 :: new ( 1 ) ,
33- sessions : RWLockAsync :: new ( HashMap :: new ( ) ) ,
39+ sessions : RWLockAsync :: new ( SessionManagerInner {
40+ sessions : HashMap :: new ( ) ,
41+ uid_to_session : HashMap :: new ( ) ,
42+ } ) ,
3443 temp_table_size : AtomicUsize :: new ( 0 ) ,
3544 server,
3645 }
@@ -41,60 +50,72 @@ impl SessionManager {
4150 protocol : DatabaseProtocol ,
4251 client_addr : String ,
4352 client_port : u16 ,
44- ) -> Arc < Session > {
53+ extra_id : Option < String > ,
54+ ) -> Result < Arc < Session > , CubeError > {
4555 let connection_id = self . last_id . fetch_add ( 1 , Ordering :: SeqCst ) ;
4656
47- let sess = Session {
57+ let session_ref = Arc :: new ( Session {
4858 session_manager : self . clone ( ) ,
4959 server : self . server . clone ( ) ,
5060 state : Arc :: new ( SessionState :: new (
5161 connection_id,
62+ extra_id. clone ( ) ,
5263 client_addr,
5364 client_port,
5465 protocol,
5566 None ,
5667 Duration :: from_secs ( self . server . config_obj . auth_expire_secs ( ) ) ,
5768 Arc :: downgrade ( self ) ,
5869 ) ) ,
59- } ;
60-
61- let session_ref = Arc :: new ( sess) ;
70+ } ) ;
6271
6372 let mut guard = self . sessions . write ( ) . await ;
6473
65- guard. insert ( connection_id, session_ref. clone ( ) ) ;
74+ if let Some ( extra_id) = extra_id {
75+ if guard. uid_to_session . contains_key ( & extra_id) {
76+ return Err ( CubeError :: user ( format ! (
77+ "Session cannot be created, because extra_id: {} already exists" ,
78+ extra_id
79+ ) ) ) ;
80+ }
6681
67- session_ref
68- }
82+ guard . uid_to_session . insert ( extra_id , session_ref. clone ( ) ) ;
83+ }
6984
70- pub async fn stat_activity ( self : & Arc < Self > ) -> Vec < SessionStatActivity > {
71- let guard = self . sessions . read ( ) . await ;
85+ guard. sessions . insert ( connection_id, session_ref. clone ( ) ) ;
7286
73- guard
74- . values ( )
75- . map ( Session :: to_stat_activity)
76- . collect :: < Vec < SessionStatActivity > > ( )
87+ Ok ( session_ref)
7788 }
7889
79- pub async fn process_list ( self : & Arc < Self > ) -> Vec < SessionProcessList > {
90+ pub async fn map_sessions < T : for < ' a > From < & ' a Arc < Session > > > ( self : & Arc < Self > ) -> Vec < T > {
8091 let guard = self . sessions . read ( ) . await ;
8192
8293 guard
94+ . sessions
8395 . values ( )
84- . map ( Session :: to_process_list )
85- . collect :: < Vec < SessionProcessList > > ( )
96+ . map ( |session| T :: from ( session ) )
97+ . collect :: < Vec < T > > ( )
8698 }
8799
88100 pub async fn get_session ( & self , connection_id : u32 ) -> Option < Arc < Session > > {
89101 let guard = self . sessions . read ( ) . await ;
90102
91- guard. get ( & connection_id) . map ( |s| s. clone ( ) )
103+ guard. sessions . get ( & connection_id) . map ( |s| s. clone ( ) )
104+ }
105+
106+ pub async fn get_session_by_extra_id ( & self , extra_id : String ) -> Option < Arc < Session > > {
107+ let guard = self . sessions . read ( ) . await ;
108+ guard. uid_to_session . get ( & extra_id) . map ( |s| s. clone ( ) )
92109 }
93110
94111 pub async fn drop_session ( & self , connection_id : u32 ) {
95112 let mut guard = self . sessions . write ( ) . await ;
96113
97- if let Some ( connection) = guard. remove ( & connection_id) {
114+ if let Some ( connection) = guard. sessions . remove ( & connection_id) {
115+ if let Some ( extra_id) = & connection. state . extra_id {
116+ guard. uid_to_session . remove ( extra_id) ;
117+ }
118+
98119 self . temp_table_size . fetch_sub (
99120 connection. state . temp_tables ( ) . physical_size ( ) ,
100121 Ordering :: SeqCst ,
0 commit comments