Skip to content

Commit 251de28

Browse files
committed
Gracefully terminate watcher process on windows
1 parent 6ddaf00 commit 251de28

File tree

7 files changed

+238
-31
lines changed

7 files changed

+238
-31
lines changed

internal/pkg/agent/application/upgrade/watcher.go

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"context"
99
"errors"
1010
"fmt"
11-
"os"
1211
"os/exec"
1312
"path/filepath"
1413
"time"
@@ -20,8 +19,6 @@ import (
2019
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
2120
"github.com/elastic/elastic-agent/pkg/control/v2/client"
2221
"github.com/elastic/elastic-agent/pkg/core/logger"
23-
"github.com/elastic/elastic-agent/pkg/core/process"
24-
"github.com/elastic/elastic-agent/pkg/utils"
2522
)
2623

2724
const (
@@ -286,17 +283,14 @@ func (a AgentWatcherHelper) WaitForWatcher(ctx context.Context, log *logger.Logg
286283
}
287284

288285
func (a AgentWatcherHelper) TakeOverWatcher(ctx context.Context, log *logger.Logger, topDir string) (*filelock.AppLocker, error) {
289-
return takeOverWatcher(ctx, log, topDir, utils.GetWatcherPIDs, 30*time.Second, 500*time.Millisecond, 100*time.Millisecond)
286+
return takeOverWatcher(ctx, log, topDir, 30*time.Second, 500*time.Millisecond, 100*time.Millisecond)
290287
}
291288

292-
// watcherPIDsFetcher defines the type of function responsible for fetching watcher PIDs.
293-
// This will allow for easier testing of takeOverWatcher using fake binaries
294-
type watcherPIDsFetcher func() ([]int, error)
295-
296289
// Private functions providing implementation of AgentWatcherHelper
297-
func takeOverWatcher(ctx context.Context, log *logger.Logger, topDir string, pidFetchFunc watcherPIDsFetcher, timeout time.Duration, watcherSweepInterval time.Duration, takeOverInterval time.Duration) (*filelock.AppLocker, error) {
290+
func takeOverWatcher(ctx context.Context, log *logger.Logger, topDir string, timeout time.Duration, watcherSweepInterval time.Duration, takeOverInterval time.Duration) (*filelock.AppLocker, error) {
298291
takeoverCtx, takeoverCancel := context.WithTimeout(ctx, timeout)
299292
defer takeoverCancel()
293+
300294
go func() {
301295
sweepTicker := time.NewTicker(watcherSweepInterval)
302296
defer sweepTicker.Stop()
@@ -305,27 +299,15 @@ func takeOverWatcher(ctx context.Context, log *logger.Logger, topDir string, pid
305299
case <-takeoverCtx.Done():
306300
return
307301
case <-sweepTicker.C:
308-
pids, err := pidFetchFunc()
302+
cmd := createTakeDownWatcherCommand(takeoverCtx)
303+
log.Debugf("launching takedown with %v", cmd.Args)
304+
output, err := cmd.CombinedOutput()
305+
log.Debugf("takedown output: %s", string(output))
309306
if err != nil {
310-
log.Errorf("error listing watcher processes: %s", err)
307+
log.Errorf("error taking down watcher: %s", err)
311308
continue
312309
}
313310

314-
// this should be run continuously and concurrently attempting to get the app locker
315-
for _, pid := range pids {
316-
log.Debugf("attempting to kill watcher process with PID: %d", pid)
317-
watcherProcess, findProcErr := os.FindProcess(pid)
318-
if findProcErr != nil {
319-
log.Errorf("error finding process with PID: %d: %s", pid, findProcErr)
320-
continue
321-
}
322-
killProcErr := process.Terminate(watcherProcess)
323-
if killProcErr != nil {
324-
log.Errorf("error killing process with PID: %d: %s", pid, killProcErr)
325-
continue
326-
}
327-
log.Debugf("killed watcher process with PID: %d", pid)
328-
}
329311
}
330312
}
331313
}()
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2+
// or more contributor license agreements. Licensed under the Elastic License 2.0;
3+
// you may not use this file except in compliance with the Elastic License 2.0.
4+
5+
//go:build !windows
6+
7+
package upgrade
8+
9+
import (
10+
"context"
11+
"os"
12+
"os/exec"
13+
14+
"github.com/elastic/elastic-agent/internal/pkg/agent/application/paths"
15+
)
16+
17+
func createTakeDownWatcherCommand(ctx context.Context) *exec.Cmd {
18+
executable, _ := os.Executable()
19+
20+
// #nosec G204 -- user cannot inject any parameters to this command
21+
cmd := exec.CommandContext(ctx, executable, watcherSubcommand,
22+
"--path.config", paths.Config(),
23+
"--path.home", paths.Top(),
24+
"--takedown",
25+
)
26+
return cmd
27+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2+
// or more contributor license agreements. Licensed under the Elastic License 2.0;
3+
// you may not use this file except in compliance with the Elastic License 2.0.
4+
5+
//go:build windows
6+
7+
package upgrade
8+
9+
import (
10+
"context"
11+
"os"
12+
"os/exec"
13+
"syscall"
14+
15+
"golang.org/x/sys/windows"
16+
17+
"github.com/elastic/elastic-agent/internal/pkg/agent/application/paths"
18+
)
19+
20+
func createTakeDownWatcherCommand(ctx context.Context) *exec.Cmd {
21+
executable, _ := os.Executable()
22+
23+
// #nosec G204 -- user cannot inject any parameters to this command
24+
cmd := exec.CommandContext(ctx, executable, watcherSubcommand,
25+
"--path.config", paths.Config(),
26+
"--path.home", paths.Top(),
27+
"--takedown",
28+
)
29+
cmd.SysProcAttr = &syscall.SysProcAttr{
30+
// https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags
31+
CreationFlags: windows.DETACHED_PROCESS,
32+
}
33+
return cmd
34+
}

internal/pkg/agent/cmd/watch.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/elastic/elastic-agent-libs/logp"
1717
"github.com/elastic/elastic-agent-libs/logp/configure"
1818
"github.com/elastic/elastic-agent/pkg/control/v2/client"
19+
"github.com/elastic/elastic-agent/pkg/utils"
1920

2021
"github.com/elastic/elastic-agent/internal/pkg/agent/application/filelock"
2122
"github.com/elastic/elastic-agent/internal/pkg/agent/application/paths"
@@ -42,7 +43,7 @@ func newWatchCommandWithArgs(_ []string, streams *cli.IOStreams) *cobra.Command
4243
Use: "watch",
4344
Short: "Watch the Elastic Agent for failures and initiate rollback",
4445
Long: `This command watches Elastic Agent for failures and initiates rollback if necessary.`,
45-
Run: func(_ *cobra.Command, _ []string) {
46+
Run: func(c *cobra.Command, _ []string) {
4647
cfg := getConfig(streams)
4748
log, err := configuredLogger(cfg, watcherName)
4849
if err != nil {
@@ -53,14 +54,25 @@ func newWatchCommandWithArgs(_ []string, streams *cli.IOStreams) *cobra.Command
5354
// Make sure to flush any buffered logs before we're done.
5455
defer log.Sync() //nolint:errcheck // flushing buffered logs is best effort.
5556

57+
takedown, _ := c.Flags().GetBool("takedown")
58+
if takedown {
59+
err = takedownWatcher(log, utils.GetWatcherPIDs)
60+
if err != nil {
61+
log.Errorf("error taking down watcher: %v", err)
62+
os.Exit(5)
63+
}
64+
return
65+
}
66+
5667
if err := watchCmd(log, paths.Top(), cfg.Settings.Upgrade.Watcher, new(upgradeAgentWatcher), new(upgradeInstallationModifier)); err != nil {
5768
log.Errorw("Watch command failed", "error.message", err)
5869
fmt.Fprintf(streams.Err, "Watch command failed: %v\n%s\n", err, troubleshootMessage())
5970
os.Exit(4)
6071
}
6172
},
6273
}
63-
74+
cmd.Flags().BoolP("takedown", "t", false, "Take down the running watcher")
75+
cmd.Flags().MarkHidden("takedown") //nolint:errcheck // not required
6476
return cmd
6577
}
6678

@@ -104,7 +116,7 @@ func watchCmd(log *logp.Logger, topDir string, cfg *configuration.UpgradeWatcher
104116

105117
if marker.DesiredOutcome == upgrade.OUTCOME_ROLLBACK {
106118
// TODO: there should be some sanity check in rollback functions like the installation we are going back to should exist and work
107-
log.Info("rolling back because of DesiredOutcome=%s", marker.DesiredOutcome.String())
119+
log.Infof("rolling back because of DesiredOutcome=%s", marker.DesiredOutcome.String())
108120
err = installModifier.Rollback(context.Background(), log, client.New(), paths.Top(), marker.PrevVersionedHome, marker.PrevHash)
109121
if err != nil {
110122
return fmt.Errorf("rolling back: %w", err)
@@ -118,7 +130,7 @@ func watchCmd(log *logp.Logger, topDir string, cfg *configuration.UpgradeWatcher
118130
marker.Details = details.NewDetails(marker.Version, details.StateRollback, actionID)
119131
}
120132
marker.Details.SetStateWithReason(details.StateRollback, details.ReasonManualRollback)
121-
err := upgrade.SaveMarker(dataDir, marker, true)
133+
err = upgrade.SaveMarker(dataDir, marker, true)
122134
if err != nil {
123135
return fmt.Errorf("saving marker after rolling back: %w", err)
124136
}

internal/pkg/agent/cmd/watch_impl.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func watch(ctx context.Context, tilGrace time.Duration, errorCheckInterval time.
4848
go agtWatcher.Run(ctx)
4949

5050
signals := make(chan os.Signal, 1)
51-
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
51+
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGHUP)
5252

5353
graceTimer := time.NewTimer(tilGrace)
5454
defer graceTimer.Stop()
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2+
// or more contributor license agreements. Licensed under the Elastic License 2.0;
3+
// you may not use this file except in compliance with the Elastic License 2.0.
4+
5+
//go:build !windows
6+
7+
package cmd
8+
9+
import (
10+
"os"
11+
"syscall"
12+
)
13+
14+
func takedownWatcher(log *logger.Logger, pidFetchFunc watcherPIDsFetcher) error {
15+
pids, err := pidFetchFunc()
16+
if err != nil {
17+
return fmt.Errorf("error listing watcher processes: %s", err)
18+
}
19+
20+
ownPID := os.Getpid()
21+
22+
for _, pid := range pids {
23+
24+
if pid == ownPID {
25+
continue
26+
}
27+
28+
log.Debugf("attempting to terminate watcher process with PID: %d", pid)
29+
30+
process, err := os.FindProcess(pid)
31+
if err != nil {
32+
log.Errorf("error finding watcher process with PID: %d: %s", pid, err)
33+
continue
34+
}
35+
36+
err = process.Signal(syscall.SIGTERM)
37+
if err != nil {
38+
log.Errorf("error killing watcher process with PID: %d: %s", pid, err)
39+
continue
40+
}
41+
42+
}
43+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
2+
// or more contributor license agreements. Licensed under the Elastic License 2.0;
3+
// you may not use this file except in compliance with the Elastic License 2.0.
4+
5+
//go:build windows
6+
7+
package cmd
8+
9+
import (
10+
"fmt"
11+
"os"
12+
"unsafe"
13+
14+
gowindows "golang.org/x/sys/windows"
15+
16+
"github.com/elastic/elastic-agent/pkg/core/logger"
17+
)
18+
19+
// watcherPIDsFetcher defines the type of function responsible for fetching watcher PIDs.
20+
// This will allow for easier testing of takeOverWatcher using fake binaries
21+
type watcherPIDsFetcher func() ([]int, error)
22+
23+
var (
24+
kernel32API = gowindows.NewLazySystemDLL("kernel32.dll")
25+
26+
freeConsoleProc = kernel32API.NewProc("FreeConsole")
27+
attachConsoleProc = kernel32API.NewProc("AttachConsole")
28+
procGetConsoleProcessList = kernel32API.NewProc("GetConsoleProcessList")
29+
)
30+
31+
func takedownWatcher(log *logger.Logger, pidFetchFunc watcherPIDsFetcher) error {
32+
pids, err := pidFetchFunc()
33+
if err != nil {
34+
return fmt.Errorf("error listing watcher processes: %s", err)
35+
}
36+
37+
ownPID := os.Getpid()
38+
39+
for _, pid := range pids {
40+
41+
if pid == ownPID {
42+
continue
43+
}
44+
45+
log.Debugf("attempting to terminate watcher process with PID: %d", pid)
46+
// define an anonymous function in order to leverage the defer() for freeing console and other housekeeping
47+
func() {
48+
49+
r1, _, consoleErr := freeConsoleProc.Call()
50+
if r1 == 0 {
51+
log.Errorf("error preemptively detaching from console: %s", consoleErr)
52+
}
53+
54+
r1, _, consoleErr = attachConsoleProc.Call(uintptr(pid))
55+
if r1 == 0 {
56+
log.Errorf("error attaching console to watcher process with PID: %d -> %s", pid, consoleErr)
57+
return
58+
}
59+
log.Infof("successfully attached console with PID: %d", pid)
60+
61+
defer func() {
62+
r1, _, consoleErr = freeConsoleProc.Call()
63+
if r1 == 0 {
64+
log.Errorf("error detaching from console: %s", consoleErr)
65+
} else {
66+
log.Infof("successfully detached from console of PID: %d", pid)
67+
}
68+
}()
69+
70+
list, consoleErr := GetConsoleProcessList()
71+
if consoleErr != nil {
72+
log.Errorf("error listing console processes: %s", consoleErr)
73+
}
74+
75+
log.Infof("Own PID: %d, Watcher pid %d, Process list on console: %v", os.Getpid(), pid, list)
76+
77+
// Normally we would want to send the Ctrl+Break event only to the watcher process but due to the fact that
78+
// the parent process of the watcher has already terminated, we have to hug it tightly and take it down with us
79+
// by specifying processGroupID=0
80+
killProcErr := gowindows.GenerateConsoleCtrlEvent(gowindows.CTRL_BREAK_EVENT, 0)
81+
82+
if killProcErr != nil {
83+
log.Errorf("error terminating process with PID: %d: %s", pid, killProcErr)
84+
return
85+
}
86+
}()
87+
88+
}
89+
return nil
90+
}
91+
92+
// GetConsoleProcessList retrieves the list of process IDs attached to the current console
93+
func GetConsoleProcessList() ([]uint32, error) {
94+
// Allocate a buffer for PIDs
95+
const maxProcs = 64
96+
pids := make([]uint32, maxProcs)
97+
98+
r1, _, err := procGetConsoleProcessList.Call(
99+
uintptr(unsafe.Pointer(&pids[0])),
100+
uintptr(maxProcs),
101+
)
102+
103+
count := uint32(r1)
104+
if count == 0 {
105+
return nil, err
106+
}
107+
108+
return pids[:count], nil
109+
}

0 commit comments

Comments
 (0)