@@ -838,12 +838,71 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
838838 return nil , fmt .Errorf ("expiry must be in the future" )
839839 }
840840
841+ // If the privacy mapper is being used for this session, then we need
842+ // to keep track of all our known privacy map pairs for this session
843+ // along with any new pairs that we need to persist.
841844 var (
842845 privacy = ! req .NoPrivacyMapper
843- privacyMapPairs = make (map [string ]string )
844846 knownPrivMapPairs = firewalldb .NewPrivacyMapPairs (nil )
847+ newPrivMapPairs = make (map [string ]string )
845848 )
846849
850+ // If a previous session ID has been set to link this new one to, we
851+ // first check if we have the referenced session, and we make sure it
852+ // has been revoked.
853+ var (
854+ linkedGroupID * session.ID
855+ linkedGroupSession * session.Session
856+ )
857+ if len (req .LinkedGroupId ) != 0 {
858+ var groupID session.ID
859+ copy (groupID [:], req .LinkedGroupId )
860+
861+ // Check that the group actually does exist.
862+ groupSess , err := s .cfg .db .GetSessionByID (groupID )
863+ if err != nil {
864+ return nil , err
865+ }
866+
867+ // Ensure that the linked session is in fact the first session
868+ // in its group.
869+ if groupSess .ID != groupSess .GroupID {
870+ return nil , fmt .Errorf ("can not link to session " +
871+ "%x since it is not the first in the session " +
872+ "group %x" , groupSess .ID , groupSess .GroupID )
873+ }
874+
875+ // Now we need to check that all the sessions in the group are
876+ // no longer active.
877+ ok , err := s .cfg .db .CheckSessionGroupPredicate (
878+ groupID , func (s * session.Session ) bool {
879+ return s .State == session .StateRevoked ||
880+ s .State == session .StateExpired
881+ },
882+ )
883+ if err != nil {
884+ return nil , err
885+ }
886+
887+ if ! ok {
888+ return nil , fmt .Errorf ("a linked session in group " +
889+ "%x is still active" , groupID )
890+ }
891+
892+ linkedGroupID = & groupID
893+ linkedGroupSession = groupSess
894+
895+ privDB := s .cfg .privMap (groupID )
896+ err = privDB .View (func (tx firewalldb.PrivacyMapTx ) error {
897+ knownPrivMapPairs , err = tx .FetchAllPairs ()
898+
899+ return err
900+ })
901+ if err != nil {
902+ return nil , err
903+ }
904+ }
905+
847906 // First need to fetch all the perms that need to be baked into this
848907 // mac based on the features.
849908 allFeatures , err := s .cfg .autopilot .ListFeatures (ctx )
@@ -892,8 +951,21 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
892951 return nil , err
893952 }
894953
954+ // Store the new privacy map pairs in
955+ // the newPrivMap pairs map so that
956+ // they are later persisted to the real
957+ // priv map db.
895958 for k , v := range privMapPairs {
896- privacyMapPairs [k ] = v
959+ newPrivMapPairs [k ] = v
960+ }
961+
962+ // Also add the new pairs to the known
963+ // set of pairs.
964+ err = knownPrivMapPairs .Add (
965+ privMapPairs ,
966+ )
967+ if err != nil {
968+ return nil , err
897969 }
898970 }
899971
@@ -1017,52 +1089,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
10171089 caveats = append (caveats , firewall .MetaPrivacyCaveat )
10181090 }
10191091
1020- // If a previous session ID has been set to link this new one to, we
1021- // first check if we have the referenced session, and we make sure it
1022- // has been revoked.
1023- var (
1024- linkedGroupID * session.ID
1025- linkedGroupSession * session.Session
1026- )
1027- if len (req .LinkedGroupId ) != 0 {
1028- var groupID session.ID
1029- copy (groupID [:], req .LinkedGroupId )
1030-
1031- // Check that the group actually does exist.
1032- groupSess , err := s .cfg .db .GetSessionByID (groupID )
1033- if err != nil {
1034- return nil , err
1035- }
1036-
1037- // Ensure that the linked session is in fact the first session
1038- // in its group.
1039- if groupSess .ID != groupSess .GroupID {
1040- return nil , fmt .Errorf ("can not link to session " +
1041- "%x since it is not the first in the session " +
1042- "group %x" , groupSess .ID , groupSess .GroupID )
1043- }
1044-
1045- // Now we need to check that all the sessions in the group are
1046- // no longer active.
1047- ok , err := s .cfg .db .CheckSessionGroupPredicate (
1048- groupID , func (s * session.Session ) bool {
1049- return s .State == session .StateRevoked ||
1050- s .State == session .StateExpired
1051- },
1052- )
1053- if err != nil {
1054- return nil , err
1055- }
1056-
1057- if ! ok {
1058- return nil , fmt .Errorf ("a linked session in group " +
1059- "%x is still active" , groupID )
1060- }
1061-
1062- linkedGroupID = & groupID
1063- linkedGroupSession = groupSess
1064- }
1065-
10661092 s .sessRegMu .Lock ()
10671093 defer s .sessRegMu .Unlock ()
10681094
@@ -1101,7 +1127,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
11011127 // Register all the privacy map pairs for this session ID.
11021128 privDB := s .cfg .privMap (sess .GroupID )
11031129 err = privDB .Update (func (tx firewalldb.PrivacyMapTx ) error {
1104- for r , p := range privacyMapPairs {
1130+ for r , p := range newPrivMapPairs {
11051131 err := tx .NewPair (r , p )
11061132 if err != nil {
11071133 return err
0 commit comments