Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
Expand All @@ -75,6 +76,9 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
// skipAutoUpdateVersion used as a placeholder for autoUpdateVersion in proto responses to indicate response contains no new updates
skipAutoUpdateVersion = "skip"
disableAutoUpdate = "disabled"
)

var ErrResetConnection = fmt.Errorf("reset connection")
Expand Down Expand Up @@ -198,6 +202,9 @@ type Engine struct {
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager

// auto-update
updateManager *updatemanager.UpdateManager
}

// Peer is an instance of the Connection Peer
Expand Down Expand Up @@ -306,6 +313,10 @@ func (e *Engine) Stop() error {
e.srWatcher.Close()
}

if e.updateManager != nil {
e.updateManager.Stop()
}

e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
Expand Down Expand Up @@ -692,10 +703,26 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}

func (e *Engine) handleAutoUpdateVersion(autoUpdateVersion string) {
if autoUpdateVersion != skipAutoUpdateVersion {
if e.updateManager == nil && autoUpdateVersion != disableAutoUpdate {
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder)
e.updateManager.Start(e.ctx)
} else if e.updateManager != nil && autoUpdateVersion == disableAutoUpdate {
e.updateManager.Stop()
e.updateManager = nil
}
if e.updateManager != nil {
e.updateManager.SetVersion(autoUpdateVersion)
}
}
}

func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()

e.handleAutoUpdateVersion(update.AutoUpdateVersion)
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
Expand Down
261 changes: 261 additions & 0 deletions client/internal/updatemanager/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
package updatemanager

import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"

v "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/internal/peer"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)

const (
latestVersion = "latest"
)

type UpdateInterface interface {
StopWatch()
SetDaemonVersion(newVersion string) bool
SetOnUpdateListener(updateFn func())
LatestVersion() *v.Version
StartFetcher()
}

type UpdateManager struct {
lastTrigger time.Time
statusRecorder *peer.Status
mgmUpdateChan chan struct{}
updateChannel chan struct{}
wg sync.WaitGroup
currentVersion string
updateFunc func(ctx context.Context, targetVersion string) error

cancel context.CancelFunc
update UpdateInterface

expectedVersion *v.Version
updateToLatestVersion bool
expectedVersionMutex sync.Mutex
}

func NewUpdateManager(statusRecorder *peer.Status) *UpdateManager {
manager := &UpdateManager{
statusRecorder: statusRecorder,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
updateFunc: triggerUpdate,
update: version.NewUpdate("nb/client"),
}
return manager
}

func (u *UpdateManager) WithCustomVersionUpdate(versionUpdate UpdateInterface) *UpdateManager {
u.update = versionUpdate
return u
}

func (u *UpdateManager) Start(ctx context.Context) {
if u.cancel != nil {
log.Errorf("UpdateManager already started")
return
}

go u.update.StartFetcher()
u.update.SetDaemonVersion(u.currentVersion)
u.update.SetOnUpdateListener(func() {
select {
case u.updateChannel <- struct{}{}:
default:
}
})

ctx, cancel := context.WithCancel(ctx)
u.cancel = cancel

u.wg.Add(1)
go u.updateLoop(ctx)
}

func (u *UpdateManager) SetVersion(expectedVersion string) {
if u.cancel == nil {
log.Errorf("UpdateManager not started")
return
}

u.expectedVersionMutex.Lock()
defer u.expectedVersionMutex.Unlock()
if expectedVersion == latestVersion {
u.updateToLatestVersion = true
u.expectedVersion = nil
} else {
expectedSemVer, err := v.NewVersion(expectedVersion)
if err != nil {
log.Errorf("Error parsing version: %v", err)
return
}
if u.expectedVersion.Equal(expectedSemVer) {
return
}
u.expectedVersion = expectedSemVer
u.updateToLatestVersion = false
}

select {
case u.mgmUpdateChan <- struct{}{}:
default:
}
}

func (u *UpdateManager) Stop() {
if u.cancel == nil {
return
}

u.cancel()
if u.update != nil {
u.update.StopWatch()
u.update = nil
}

u.wg.Wait()
}

func (u *UpdateManager) updateLoop(ctx context.Context) {
defer u.wg.Done()

for {
select {
case <-ctx.Done():
return
case <-u.mgmUpdateChan:
case <-u.updateChannel:
}

u.handleUpdate(ctx)
}
}

func (u *UpdateManager) handleUpdate(ctx context.Context) {
var updateVersion *v.Version

u.expectedVersionMutex.Lock()
expectedVersion := u.expectedVersion
useLatest := u.updateToLatestVersion
curLatestVersion := u.update.LatestVersion()
u.expectedVersionMutex.Unlock()

switch {
// Resolve "latest" to actual version
case useLatest:
if curLatestVersion == nil {
log.Tracef("Latest version not fetched yet")
return
}
updateVersion = curLatestVersion
// Update to specific version
case u.expectedVersion != nil:
updateVersion = expectedVersion
default:
log.Debugf("No expected version information set")
return
}

if !u.shouldUpdate(updateVersion) {
return
}

ctx, cancel := context.WithDeadline(ctx, time.Now().Add(time.Minute))
defer cancel()

u.lastTrigger = time.Now()
log.Debugf("Auto-update triggered, current version: %s, target version: %s", u.currentVersion, updateVersion)
u.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)

err := u.updateFunc(ctx, updateVersion.String())
if err != nil {
log.Errorf("Error triggering auto-update: %v", err)
}
}

func (u *UpdateManager) shouldUpdate(updateVersion *v.Version) bool {
currentVersion, err := v.NewVersion(u.currentVersion)
if err != nil {
log.Errorf("Error checking for update, error parsing version `%s`: %v", u.currentVersion, err)
return false
}
if currentVersion.GreaterThanOrEqual(updateVersion) {
log.Debugf("Current version (%s) is equal to or higher than auto-update version (%s)", u.currentVersion, updateVersion)
return false
}

if time.Since(u.lastTrigger) < 5*time.Minute {
log.Tracef("No need to update")
return false
}

return true
}

func downloadFileToTemporaryDir(ctx context.Context, fileURL string) (string, error) { //nolint:unused
tempDir, err := os.MkdirTemp("", "netbird-installer-*")
if err != nil {
return "", fmt.Errorf("error creating temporary directory: %w", err)
}
fileNameParts := strings.Split(fileURL, "/")
out, err := os.Create(filepath.Join(tempDir, fileNameParts[len(fileNameParts)-1]))
if err != nil {
return "", fmt.Errorf("error creating temporary file: %w", err)
}
defer func() {
if err := out.Close(); err != nil {
log.Errorf("Error closing temporary file: %v", err)
}
}()

req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
if err != nil {
return "", fmt.Errorf("error creating file download request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("error downloading file: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.Errorf("Error closing response body: %v", err)
}
}()

_, err = io.Copy(out, resp.Body)
if err != nil {
return "", fmt.Errorf("error downloading file: %w", err)
}

log.Tracef("Downloaded update file to %s", out.Name())

return out.Name(), nil
}

func urlWithVersionArch(url, version string) string { //nolint:unused
url = strings.ReplaceAll(url, "%version", version)
url = strings.ReplaceAll(url, "%arch", runtime.GOARCH)
return url
}
Loading
Loading