Skip to content

Commit b380f49

Browse files
committed
Fix deadlock issues
1 parent b30e2e7 commit b380f49

File tree

3 files changed

+147
-104
lines changed

3 files changed

+147
-104
lines changed

client/internal/engine.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ func NewEngine(
242242
statusRecorder: statusRecorder,
243243
checks: checks,
244244
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
245-
updateManager: updatemanager.NewUpdateManager(clientCtx, statusRecorder),
246245
}
247246

248247
sm := profilemanager.NewServiceManager("")
@@ -674,7 +673,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
674673
e.syncMsgMux.Lock()
675674
defer e.syncMsgMux.Unlock()
676675

677-
if update.GetAutoUpdateVersion() != "skip" {
676+
if e.updateManager == nil && update.GetAutoUpdateVersion() != "disabled" {
677+
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder)
678+
e.updateManager.Start(e.ctx)
679+
} else if e.updateManager != nil && update.GetAutoUpdateVersion() == "disabled" {
680+
e.updateManager.Stop()
681+
e.updateManager = nil
682+
}
683+
if e.updateManager != nil {
678684
e.updateManager.SetVersion(update.GetAutoUpdateVersion())
679685
}
680686
if update.GetNetbirdConfig() != nil {

client/internal/updatemanager/manager.go

Lines changed: 127 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,129 +21,174 @@ import (
2121
)
2222

2323
const (
24-
latestVersion = "latest"
25-
disableAutoUpdate = "disabled"
26-
unknownVersion = "Unknown"
24+
latestVersion = "latest"
2725
)
2826

2927
type UpdateManager struct {
30-
ctx context.Context
31-
cancel context.CancelFunc
32-
version string
33-
latestVersion string
34-
update *version.Update
3528
lastTrigger time.Time
3629
statusRecorder *peer.Status
37-
mutex sync.Mutex
38-
updateChannel chan string
39-
doneChannel chan struct{}
30+
mgmUpdateChan chan struct{}
31+
updateChannel chan struct{}
32+
wg sync.WaitGroup
33+
34+
cancel context.CancelFunc
35+
update *version.Update
36+
37+
expectedVersion string
38+
expectedVersionMutex sync.Mutex
4039
}
4140

42-
func NewUpdateManager(ctx context.Context, statusRecorder *peer.Status) *UpdateManager {
43-
update := version.NewUpdate("nb/client")
44-
ctx, cancel := context.WithCancel(ctx)
41+
func NewUpdateManager(statusRecorder *peer.Status) *UpdateManager {
4542
manager := &UpdateManager{
46-
update: update,
47-
lastTrigger: time.Now().Add(-10 * time.Minute),
4843
statusRecorder: statusRecorder,
49-
ctx: ctx,
50-
cancel: cancel,
51-
version: disableAutoUpdate,
52-
latestVersion: unknownVersion,
53-
updateChannel: make(chan string, 4),
54-
doneChannel: make(chan struct{}),
55-
}
56-
update.SetDaemonVersion(version.NetbirdVersion())
57-
update.SetOnUpdateChannel(manager.updateChannel)
58-
go manager.UpdateLoop()
44+
mgmUpdateChan: make(chan struct{}, 1),
45+
updateChannel: make(chan struct{}, 1),
46+
}
5947
return manager
6048
}
6149

50+
func (u *UpdateManager) Start(ctx context.Context) {
51+
if u.cancel != nil {
52+
log.Errorf("UpdateManager already started")
53+
return
54+
}
55+
56+
u.update = version.NewUpdate("nb/client")
57+
u.update.SetDaemonVersion(version.NetbirdVersion())
58+
u.update.SetOnUpdateListener(func() {
59+
select {
60+
case u.updateChannel <- struct{}{}:
61+
default:
62+
}
63+
})
64+
65+
ctx, cancel := context.WithCancel(ctx)
66+
u.cancel = cancel
67+
68+
u.wg.Add(1)
69+
go u.updateLoop(ctx)
70+
}
71+
6272
func (u *UpdateManager) SetVersion(v string) {
63-
u.mutex.Lock()
64-
if u.version != v {
65-
log.Tracef("Auto-update version set to %s", v)
66-
u.version = v
67-
u.mutex.Unlock()
68-
u.updateChannel <- unknownVersion
69-
} else {
70-
u.mutex.Unlock()
73+
if u.cancel == nil {
74+
log.Errorf("UpdateManager not started")
75+
return
76+
}
77+
78+
u.expectedVersionMutex.Lock()
79+
defer u.expectedVersionMutex.Unlock()
80+
if u.expectedVersion == v {
81+
return
82+
}
83+
84+
u.expectedVersion = v
85+
86+
select {
87+
case u.mgmUpdateChan <- struct{}{}:
88+
default:
7189
}
7290
}
7391

7492
func (u *UpdateManager) Stop() {
93+
if u.cancel == nil {
94+
return
95+
}
96+
7597
u.cancel()
76-
u.mutex.Lock()
77-
defer u.mutex.Unlock()
7898
if u.update != nil {
7999
u.update.StopWatch()
80100
u.update = nil
81101
}
82-
<-u.doneChannel
102+
103+
u.wg.Wait()
83104
}
84105

85-
func (u *UpdateManager) UpdateLoop() {
106+
func (u *UpdateManager) updateLoop(ctx context.Context) {
107+
defer u.wg.Done()
108+
86109
for {
87110
select {
88-
case <-u.ctx.Done():
89-
u.doneChannel <- struct{}{}
111+
case <-ctx.Done():
90112
return
91-
case latestVersion := <-u.updateChannel:
92-
u.mutex.Lock()
93-
if latestVersion != unknownVersion {
94-
u.latestVersion = latestVersion
95-
}
96-
u.mutex.Unlock()
97-
ctx, cancel := context.WithDeadline(u.ctx, time.Now().Add(time.Minute))
98-
u.CheckForUpdates(ctx)
99-
cancel()
113+
case <-u.mgmUpdateChan:
114+
case <-u.updateChannel:
100115
}
116+
117+
u.handleUpdate(ctx)
101118
}
102119
}
103120

104-
func (u *UpdateManager) CheckForUpdates(ctx context.Context) {
105-
if u.version == disableAutoUpdate {
106-
log.Trace("Skipped checking for updates, auto-update is disabled")
107-
return
108-
}
109-
currentVersionString := version.NetbirdVersion()
110-
updateVersionString := u.version
111-
if updateVersionString == latestVersion || updateVersionString == "" {
112-
if u.latestVersion == unknownVersion {
121+
func (u *UpdateManager) handleUpdate(ctx context.Context) {
122+
var updateVersion *v.Version
123+
124+
u.expectedVersionMutex.Lock()
125+
expectedVersion := u.expectedVersion
126+
u.expectedVersionMutex.Unlock()
127+
128+
// Resolve "latest" to actual version
129+
if expectedVersion == latestVersion {
130+
if !u.isVersionAvailable() {
113131
log.Tracef("Latest version not fetched yet")
114132
return
115133
}
116-
updateVersionString = u.latestVersion
134+
updateVersion = u.update.LatestVersion()
135+
} else {
136+
var err error
137+
updateVersion, err = v.NewSemver(expectedVersion)
138+
if err != nil {
139+
log.Errorf("Failed to parse latest version: %v", err)
140+
return
141+
}
142+
}
143+
144+
if !u.shouldUpdate(updateVersion) {
145+
return
146+
}
147+
148+
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(time.Minute))
149+
defer cancel()
150+
151+
u.lastTrigger = time.Now()
152+
log.Debugf("Auto-update triggered, current version: %s, target version: %s", version.NetbirdVersion(), updateVersion)
153+
u.statusRecorder.PublishEvent(
154+
cProto.SystemEvent_INFO,
155+
cProto.SystemEvent_SYSTEM,
156+
"Automatically updating client",
157+
"Your client version is older than auto-update version set in Management, updating client now.",
158+
nil,
159+
)
160+
161+
err := u.triggerUpdate(ctx, updateVersion.String())
162+
if err != nil {
163+
log.Errorf("Error triggering auto-update: %v", err)
117164
}
165+
}
166+
167+
func (u *UpdateManager) shouldUpdate(updateVersion *v.Version) bool {
168+
currentVersionString := version.NetbirdVersion()
118169
currentVersion, err := v.NewVersion(currentVersionString)
119170
if err != nil {
120171
log.Errorf("Error checking for update, error parsing version `%s`: %v", currentVersionString, err)
121-
return
172+
return false
122173
}
123-
updateVersion, err := v.NewVersion(updateVersionString)
124-
if err != nil {
125-
log.Errorf("Error checking for update, error parsing version `%s`: %v", updateVersionString, err)
126-
return
174+
if currentVersion.GreaterThanOrEqual(updateVersion) {
175+
log.Debugf("Current version (%s) is equal to or higher than auto-update version (%s)", currentVersionString, updateVersion)
176+
return false
127177
}
128-
if currentVersion.LessThan(updateVersion) {
129-
if u.lastTrigger.Add(5 * time.Minute).Before(time.Now()) {
130-
u.lastTrigger = time.Now()
131-
log.Debugf("Auto-update triggered, current version: %s, target version: %s", currentVersionString, updateVersionString)
132-
u.statusRecorder.PublishEvent(
133-
cProto.SystemEvent_INFO,
134-
cProto.SystemEvent_SYSTEM,
135-
"Automatically updating client",
136-
"Your client version is older than auto-update version set in Management, updating client now.",
137-
nil,
138-
)
139-
err = u.triggerUpdate(ctx, updateVersionString)
140-
if err != nil {
141-
log.Errorf("Error triggering auto-update: %v", err)
142-
}
143-
}
144-
} else {
145-
log.Debugf("Current version (%s) is equal to or higher than auto-update version (%s)", currentVersionString, updateVersionString)
178+
179+
if time.Since(u.lastTrigger) < 5*time.Minute {
180+
log.Tracef("No need to update")
181+
return false
182+
}
183+
184+
return true
185+
}
186+
187+
func (u *UpdateManager) isVersionAvailable() bool {
188+
if u.update.LatestVersion() == nil {
189+
return false
146190
}
191+
return true
147192
}
148193

149194
func downloadFileToTemporaryDir(ctx context.Context, fileURL string) (string, error) { //nolint:unused

version/update.go

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ type Update struct {
3131
fetchDone chan struct{}
3232

3333
onUpdateListener func()
34-
onUpdateChannel chan string
3534
listenerLock sync.Mutex
3635
}
3736

@@ -42,14 +41,11 @@ func NewUpdate(httpAgent string) *Update {
4241
currentVersion, _ = goversion.NewVersion("0.0.0")
4342
}
4443

45-
latestAvailable, _ := goversion.NewVersion("0.0.0")
46-
4744
u := &Update{
48-
httpAgent: httpAgent,
49-
latestAvailable: latestAvailable,
50-
uiVersion: currentVersion,
51-
fetchTicker: time.NewTicker(fetchPeriod),
52-
fetchDone: make(chan struct{}),
45+
httpAgent: httpAgent,
46+
uiVersion: currentVersion,
47+
fetchTicker: time.NewTicker(fetchPeriod),
48+
fetchDone: make(chan struct{}),
5349
}
5450
go u.startFetcher()
5551
return u
@@ -95,15 +91,10 @@ func (u *Update) SetOnUpdateListener(updateFn func()) {
9591
}
9692
}
9793

98-
func (u *Update) SetOnUpdateChannel(updateChannel chan string) {
99-
u.listenerLock.Lock()
100-
defer u.listenerLock.Unlock()
101-
u.onUpdateChannel = updateChannel
102-
if u.isUpdateAvailable() {
103-
u.versionsLock.Lock()
104-
defer u.versionsLock.Unlock()
105-
u.onUpdateChannel <- u.latestAvailable.String()
106-
}
94+
func (u *Update) LatestVersion() *goversion.Version {
95+
u.versionsLock.Lock()
96+
defer u.versionsLock.Unlock()
97+
return u.latestAvailable
10798
}
10899

109100
func (u *Update) startFetcher() {
@@ -181,9 +172,6 @@ func (u *Update) checkUpdate() bool {
181172

182173
u.listenerLock.Lock()
183174
defer u.listenerLock.Unlock()
184-
if u.onUpdateChannel != nil {
185-
u.onUpdateChannel <- u.latestAvailable.String()
186-
}
187175
if u.onUpdateListener == nil {
188176
return true
189177
}
@@ -196,6 +184,10 @@ func (u *Update) isUpdateAvailable() bool {
196184
u.versionsLock.Lock()
197185
defer u.versionsLock.Unlock()
198186

187+
if u.latestAvailable == nil {
188+
return false
189+
}
190+
199191
if u.latestAvailable.GreaterThan(u.uiVersion) {
200192
return true
201193
}

0 commit comments

Comments
 (0)