Skip to content

Commit 5cab18b

Browse files
authored
fix: free old GUTI after Registration Complete and other quality fixes in AMF (#1174)
Signed-off-by: Guillaume Belanger <guillaume.belanger27@gmail.com>
1 parent 60340f0 commit 5cab18b

23 files changed

+273
-142
lines changed

internal/amf/amf.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import (
3131
// Authenticator is the interface the AMF requires from the AUSF.
3232
// *ausf.AUSF satisfies this interface directly.
3333
type Authenticator interface {
34-
Authenticate(ctx context.Context, suci, servingNetwork string, resync *ausf.ResyncInfo) (*ausf.AuthResult, error)
34+
Authenticate(ctx context.Context, suci string, plmn models.PlmnID, resync *ausf.ResyncInfo) (*ausf.AuthResult, error)
3535
Confirm(ctx context.Context, resStar, suci string) (etsi.SUPI, string, error)
3636
}
3737

@@ -241,10 +241,11 @@ func (amf *AMF) NewRadio(conn *sctp.SCTPConn) (*Radio, error) {
241241
SupportedTAIs: make([]SupportedTAI, 0),
242242
Conn: conn,
243243
ConnectedAt: now,
244-
LastSeenAt: now,
245244
Log: logger.AmfLog.With(logger.RanAddr(remoteAddr.String())),
246245
}
247246

247+
radio.SetLastSeenAt(now)
248+
248249
amf.mu.Lock()
249250
defer amf.mu.Unlock()
250251

@@ -524,7 +525,7 @@ func (amf *AMF) SendPaging(ctx context.Context, ue *AmfUe, ngapBuf []byte) error
524525
for _, ran := range amf.ListRadios() {
525526
for _, item := range ran.SupportedTAIs {
526527
if InTaiList(item.Tai, taiList) {
527-
err := ran.NGAPSender.SendToRan(ctx, ngapBuf, send.NGAPProcedurePaging)
528+
err := ran.NGAPSender.SendToRan(context.Background(), ngapBuf, send.NGAPProcedurePaging)
528529
if err != nil {
529530
ue.Log.Error("failed to send paging", zap.Error(err))
530531
continue

internal/amf/amf_ran.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package amf
1010
import (
1111
"context"
1212
"sync"
13+
"sync/atomic"
1314
"time"
1415

1516
"github.com/ellanetworks/core/internal/amf/ngap/send"
@@ -64,7 +65,7 @@ type Radio struct {
6465
Name string
6566
Conn *sctp.SCTPConn
6667
ConnectedAt time.Time
67-
LastSeenAt time.Time
68+
lastSeenAt atomic.Int64 // Unix nanoseconds; use GetLastSeenAt()/TouchLastSeen()
6869
SupportedTAIs []SupportedTAI
6970
mu sync.RWMutex // protects RanUEs
7071
RanUEs map[int64]*RanUe // Key: RanUeNgapID
@@ -127,7 +128,22 @@ func (r *Radio) SetRanID(ranNodeID *ngapType.GlobalRANNodeID) {
127128
}
128129

129130
func (r *Radio) TouchLastSeen() {
130-
r.LastSeenAt = time.Now()
131+
r.lastSeenAt.Store(time.Now().UnixNano())
132+
}
133+
134+
// GetLastSeenAt returns the last-seen timestamp. Safe for concurrent use.
135+
func (r *Radio) GetLastSeenAt() time.Time {
136+
ns := r.lastSeenAt.Load()
137+
if ns == 0 {
138+
return time.Time{}
139+
}
140+
141+
return time.Unix(0, ns)
142+
}
143+
144+
// SetLastSeenAt sets the last-seen timestamp. Safe for concurrent use.
145+
func (r *Radio) SetLastSeenAt(t time.Time) {
146+
r.lastSeenAt.Store(t.UnixNano())
131147
}
132148

133149
// NodeID returns the RAN node identifier string regardless of radio type.

internal/amf/amf_ran_test.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,18 @@ func TestRadioRanNodeTypeName(t *testing.T) {
3535
}
3636

3737
func TestRadioTouchLastSeen(t *testing.T) {
38-
radio := &amf.Radio{
39-
LastSeenAt: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
40-
}
38+
radio := &amf.Radio{}
39+
radio.SetLastSeenAt(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))
4140

4241
before := time.Now()
4342

4443
radio.TouchLastSeen()
4544

4645
after := time.Now()
4746

48-
if radio.LastSeenAt.Before(before) || radio.LastSeenAt.After(after) {
49-
t.Fatalf("expected LastSeenAt between %v and %v, got %v", before, after, radio.LastSeenAt)
47+
lastSeen := radio.GetLastSeenAt()
48+
if lastSeen.Before(before) || lastSeen.After(after) {
49+
t.Fatalf("expected LastSeenAt between %v and %v, got %v", before, after, lastSeen)
5050
}
5151
}
5252

@@ -57,23 +57,25 @@ func TestRadioTimestampsSetOnCreation(t *testing.T) {
5757
t.Fatal("expected ConnectedAt to be zero on a blank Radio")
5858
}
5959

60-
if !blank.LastSeenAt.IsZero() {
60+
if !blank.GetLastSeenAt().IsZero() {
6161
t.Fatal("expected LastSeenAt to be zero on a blank Radio")
6262
}
6363

6464
now := time.Now()
6565

6666
radio := &amf.Radio{
6767
ConnectedAt: now,
68-
LastSeenAt: now,
6968
}
7069

70+
radio.SetLastSeenAt(now)
71+
7172
if radio.ConnectedAt.IsZero() || radio.ConnectedAt != now {
7273
t.Fatalf("expected ConnectedAt to be %v, got %v", now, radio.ConnectedAt)
7374
}
7475

75-
if radio.LastSeenAt.IsZero() || radio.LastSeenAt != now {
76-
t.Fatalf("expected LastSeenAt to be %v, got %v", now, radio.LastSeenAt)
76+
lastSeen := radio.GetLastSeenAt()
77+
if lastSeen.IsZero() || !lastSeen.Equal(now) {
78+
t.Fatalf("expected LastSeenAt to be %v, got %v", now, lastSeen)
7779
}
7880
}
7981

internal/amf/amf_ue.go

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"encoding/binary"
1414
"encoding/hex"
1515
"fmt"
16-
"reflect"
1716
"slices"
1817
"sync"
1918
"time"
@@ -275,15 +274,15 @@ func (ue *AmfUe) AllocateRegistrationArea(supportedTais []models.Tai) {
275274
copy(taiList, supportedTais)
276275

277276
for _, supportTai := range taiList {
278-
if reflect.DeepEqual(supportTai, ue.Tai) {
277+
if supportTai.Equal(ue.Tai) {
279278
ue.RegistrationArea = append(ue.RegistrationArea, supportTai)
280279
break
281280
}
282281
}
283282
}
284283

285284
func (ue *AmfUe) IsAllowedNssai(targetSNssai *models.Snssai) bool {
286-
return reflect.DeepEqual(*ue.AllowedNssai, *targetSNssai)
285+
return ue.AllowedNssai.Equal(*targetSNssai)
287286
}
288287

289288
func (ue *AmfUe) SecurityContextIsValid() bool {
@@ -584,6 +583,9 @@ func (ue *AmfUe) CreateSmContext(pduSessionID uint8, ref string, snssai *models.
584583
return fmt.Errorf("invalid PDU session ID %d: must be in range 1-15 per TS 24.501", pduSessionID)
585584
}
586585

586+
ue.Mutex.Lock()
587+
defer ue.Mutex.Unlock()
588+
587589
ue.SmContextList[pduSessionID] = &SmContext{
588590
Ref: ref,
589591
Snssai: snssai,
@@ -592,12 +594,35 @@ func (ue *AmfUe) CreateSmContext(pduSessionID uint8, ref string, snssai *models.
592594
return nil
593595
}
594596

597+
func (ue *AmfUe) DeleteSmContext(pduSessionID uint8) {
598+
ue.Mutex.Lock()
599+
defer ue.Mutex.Unlock()
600+
601+
delete(ue.SmContextList, pduSessionID)
602+
}
603+
595604
func (ue *AmfUe) SmContextFindByPDUSessionID(pduSessionID uint8) (*SmContext, bool) {
605+
ue.Mutex.Lock()
606+
defer ue.Mutex.Unlock()
607+
596608
smContext, ok := ue.SmContextList[pduSessionID]
609+
597610
return smContext, ok
598611
}
599612

613+
func (ue *AmfUe) SetSmContextInactive(pduSessionID uint8) {
614+
ue.Mutex.Lock()
615+
defer ue.Mutex.Unlock()
616+
617+
if sc, ok := ue.SmContextList[pduSessionID]; ok {
618+
sc.PduSessionInactive = true
619+
}
620+
}
621+
600622
func (ue *AmfUe) HasActivePduSessions() bool {
623+
ue.Mutex.Lock()
624+
defer ue.Mutex.Unlock()
625+
601626
for _, smContext := range ue.SmContextList {
602627
if !smContext.PduSessionInactive {
603628
return true
@@ -933,8 +958,19 @@ func (ue *AmfUe) releaseSmContexts(ctx context.Context) {
933958
return
934959
}
935960

961+
// Copy refs under lock, then release lock before external SMF calls.
962+
ue.Mutex.Lock()
963+
964+
smContextRefs := make([]string, 0, len(ue.SmContextList))
936965
for _, smContext := range ue.SmContextList {
937-
err := ue.smf.ReleaseSmContext(ctx, smContext.Ref)
966+
smContextRefs = append(smContextRefs, smContext.Ref)
967+
}
968+
969+
ue.SmContextList = make(map[uint8]*SmContext)
970+
ue.Mutex.Unlock()
971+
972+
for _, smContextRef := range smContextRefs {
973+
err := ue.smf.ReleaseSmContext(ctx, smContextRef)
938974
if err != nil {
939975
ue.Log.Error("Release SmContext Error", zap.Error(err))
940976
}

internal/amf/nas/gmm/authentication_procedure.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package gmm
33
import (
44
"context"
55
"fmt"
6-
"strconv"
76

87
"github.com/ellanetworks/core/internal/amf"
98
"github.com/ellanetworks/core/internal/amf/nas/gmm/message"
@@ -16,14 +15,7 @@ func sendUEAuthenticationAuthenticateRequest(ctx context.Context, amfInstance *a
1615
return nil, fmt.Errorf("tai is not available in UE context")
1716
}
1817

19-
mnc, err := strconv.Atoi(ue.Tai.PlmnID.Mnc)
20-
if err != nil {
21-
return nil, fmt.Errorf("could not convert mnc to int: %v", err)
22-
}
23-
24-
snName := fmt.Sprintf("5G:mnc%03d.mcc%s.3gppnetwork.org", mnc, ue.Tai.PlmnID.Mcc)
25-
26-
ueAuthenticationCtx, err := amfInstance.Ausf.Authenticate(ctx, ue.Suci, snName, resyncInfo)
18+
ueAuthenticationCtx, err := amfInstance.Ausf.Authenticate(ctx, ue.Suci, *ue.Tai.PlmnID, resyncInfo)
2719
if err != nil {
2820
return nil, fmt.Errorf("ausf UE Authentication Authenticate Request failed: %s", err.Error())
2921
}

internal/amf/nas/gmm/handle_registration_complete.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ func handleRegistrationComplete(ctx context.Context, amfInstance *amf.AMF, ue *a
2222
ue.T3550 = nil // clear the timer
2323
}
2424

25+
// UE confirmed receipt of the new GUTI — free the old one (TS 24.501 5.5.1.2.4 step 20)
26+
amfInstance.FreeOldGuti(ue)
27+
2528
// Send NITZ (network name + timezone) to UE per TS 24.501
2629
message.SendConfigurationUpdateCommand(ctx, amfInstance, ue, false)
2730

internal/amf/nas/gmm/handle_service_request.go

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,24 @@ func sendServiceAccept(
112112

113113
// TS 24501 5.6.1
114114
func handleServiceRequest(ctx context.Context, amfInstance *amf.AMF, ue *amf.AmfUe, msg *nasMessage.ServiceRequest) error {
115-
if ue.GetState() != amf.Deregistered && ue.GetState() != amf.Registered {
115+
// TS 24.501 5.6.1.1: reject service request from deregistered UE
116+
if ue.GetState() == amf.Deregistered {
117+
err := message.SendServiceReject(ctx, ue.RanUe(), nasMessage.Cause5GMMUEIdentityCannotBeDerivedByTheNetwork)
118+
if err != nil {
119+
return fmt.Errorf("error sending service reject: %v", err)
120+
}
121+
122+
ue.RanUe().ReleaseAction = amf.UeContextN2NormalRelease
123+
124+
err = ue.RanUe().SendUEContextReleaseCommand(ctx, ngapType.CausePresentNas, ngapType.CauseNasPresentNormalRelease)
125+
if err != nil {
126+
return fmt.Errorf("error sending ue context release command: %v", err)
127+
}
128+
129+
return nil
130+
}
131+
132+
if ue.GetState() != amf.Registered {
116133
return fmt.Errorf("state mismatch: receive Service Request message in state %s", ue.GetState())
117134
}
118135

@@ -165,8 +182,8 @@ func handleServiceRequest(ctx context.Context, amfInstance *amf.AMF, ue *amf.Amf
165182
ue.RetransmissionOfInitialNASMsg = ue.MacFailed
166183
}
167184

168-
// Service Reject if the SecurityContext is invalid or the UE is Deregistered
169-
if !ue.SecurityContextIsValid() || ue.GetState() == amf.Deregistered {
185+
// Service Reject if the SecurityContext is invalid
186+
if !ue.SecurityContextIsValid() {
170187
ue.Log.Warn("No security context", logger.SUPI(ue.Supi.String()))
171188
ue.SecurityContextAvailable = false
172189

@@ -221,12 +238,21 @@ func handleServiceRequest(ctx context.Context, amfInstance *amf.AMF, ue *amf.Amf
221238
}
222239
}
223240

241+
// Copy SmContextList under lock for safe concurrent iteration.
242+
ue.Mutex.Lock()
243+
244+
smContextSnapshot := make(map[uint8]*amf.SmContext, len(ue.SmContextList))
245+
for id, sc := range ue.SmContextList {
246+
smContextSnapshot[id] = sc
247+
}
248+
ue.Mutex.Unlock()
249+
224250
// If the UE has uplink data pending for some PDU sessions, we need to activate them
225251
if msg.UplinkDataStatus != nil {
226252
uplinkDataPsi := nasConvert.PSIToBooleanArray(msg.UplinkDataStatus.Buffer)
227253
reactivationResult = new([16]bool)
228254

229-
for pduSessionID, smContext := range ue.SmContextList {
255+
for pduSessionID, smContext := range smContextSnapshot {
230256
if int(pduSessionID) >= len(uplinkDataPsi) {
231257
ue.Log.Warn("Ignoring out-of-range PDU session ID in UplinkDataStatus processing", zap.Uint8("pduSessionID", pduSessionID))
232258
continue
@@ -255,7 +281,7 @@ func handleServiceRequest(ctx context.Context, amfInstance *amf.AMF, ue *amf.Amf
255281
acceptPduSessionPsi = new([16]bool)
256282

257283
psiArray := nasConvert.PSIToBooleanArray(msg.PDUSessionStatus.Buffer)
258-
for pduSessionID, smContext := range ue.SmContextList {
284+
for pduSessionID, smContext := range smContextSnapshot {
259285
if int(pduSessionID) >= len(psiArray) {
260286
ue.Log.Warn("Ignoring out-of-range PDU session ID in PDUSessionStatus processing", zap.Uint8("pduSessionID", pduSessionID))
261287
continue

internal/amf/nas/gmm/handle_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ type FakeAusf struct {
261261
AvKgAka *ausf.AuthResult
262262
}
263263

264-
func (a *FakeAusf) Authenticate(ctx context.Context, suci string, servingNetwork string, resync *ausf.ResyncInfo) (*ausf.AuthResult, error) {
264+
func (a *FakeAusf) Authenticate(ctx context.Context, suci string, plmn models.PlmnID, resync *ausf.ResyncInfo) (*ausf.AuthResult, error) {
265265
if a.Error != nil {
266266
return nil, a.Error
267267
}

internal/amf/nas/gmm/handle_ul_nas_transport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func transport5GSMMessage(ctx context.Context, amfInstance *amf.AMF, ue *amf.Amf
125125
if smContextExist && requestType != nil {
126126
/* AMF releases context locally as this is duplicate pdu session */
127127
if requestType.GetRequestTypeValue() == nasMessage.ULNASTransportRequestTypeInitialRequest {
128-
delete(ue.SmContextList, pduSessionID)
128+
ue.DeleteSmContext(pduSessionID)
129129

130130
smContextExist = false
131131
}

internal/amf/nas/gmm/message/send.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func SendAuthenticationRequest(ctx context.Context, amfInstance *amf.AMF, ue *am
110110
amfUe.T3560 = amf.NewTimer(cfg.ExpireTime, cfg.MaxRetryTimes, func(expireTimes int32) {
111111
amfUe.Log.Warn("T3560 expires, retransmit Authentication Request", zap.Any("expireTimes", expireTimes))
112112

113-
err := ue.SendDownlinkNasTransport(ctx, nasMsg, nil)
113+
err := ue.SendDownlinkNasTransport(context.Background(), nasMsg, nil)
114114
if err != nil {
115115
amfUe.Log.Error("could not send downlink NAS transport message", zap.Error(err))
116116
return
@@ -268,7 +268,7 @@ func SendSecurityModeCommand(ctx context.Context, amfInstance *amf.AMF, ue *amf.
268268
amfUe.T3560 = amf.NewTimer(cfg.ExpireTime, cfg.MaxRetryTimes, func(expireTimes int32) {
269269
amfUe.Log.Warn("T3560 expires, retransmit Security Mode Command", zap.Any("expireTimes", expireTimes))
270270

271-
err = ue.SendDownlinkNasTransport(ctx, nasMsg, nil)
271+
err = ue.SendDownlinkNasTransport(context.Background(), nasMsg, nil)
272272
if err != nil {
273273
amfUe.Log.Error("could not send downlink NAS transport message", zap.Error(err))
274274
return
@@ -277,7 +277,6 @@ func SendSecurityModeCommand(ctx context.Context, amfInstance *amf.AMF, ue *amf.
277277
amfUe.Log.Info("sent security mode command")
278278
}, func() {
279279
amfUe.Log.Warn("T3560 Expires, abort security mode control procedure", zap.Any("expireTimes", cfg.MaxRetryTimes))
280-
// amfUe.Remove()
281280
amfInstance.DeregisterAndRemoveAMFUE(context.Background(), amfUe)
282281
})
283282
}
@@ -380,7 +379,7 @@ func SendRegistrationAccept(
380379
} else {
381380
if ue.RanUe().UeContextRequest && !ue.RanUe().RecvdInitialContextSetupResponse {
382381
err = ue.RanUe().SendInitialContextSetupRequest(
383-
ctx,
382+
context.Background(),
384383
ue.Ambr.Uplink,
385384
ue.Ambr.Downlink,
386385
ue.AllowedNssai,
@@ -402,7 +401,7 @@ func SendRegistrationAccept(
402401
} else {
403402
ue.Log.Warn("T3550 expires, retransmit Registration Accept", zap.Any("expireTimes", expireTimes))
404403

405-
err = ue.RanUe().SendDownlinkNasTransport(ctx, nasMsg, nil)
404+
err = ue.RanUe().SendDownlinkNasTransport(context.Background(), nasMsg, nil)
406405
if err != nil {
407406
ue.Log.Error("could not send downlink NAS transport message", zap.Error(err))
408407
}
@@ -485,7 +484,7 @@ func SendConfigurationUpdateCommand(ctx context.Context, amfInstance *amf.AMF, a
485484
return
486485
}
487486

488-
err = amfUe.RanUe().SendDownlinkNasTransport(ctx, nasMsg, mobilityRestrictionList)
487+
err = amfUe.RanUe().SendDownlinkNasTransport(context.Background(), nasMsg, mobilityRestrictionList)
489488
if err != nil {
490489
amfUe.Log.Error("could not send configuration update command", zap.Error(err))
491490
}

0 commit comments

Comments
 (0)