diff --git a/README.md b/README.md index 5e9e147..16b1284 100644 --- a/README.md +++ b/README.md @@ -155,6 +155,28 @@ func main() { * Currently shells out to `mv` for moving files because `mv` handles cross-partition moves unlike `os.Rename`. * Package `init()` functions will run twice on start, once in the main process and once in the child process. +## Windows Named Pipe Communication + +On Windows platform, since UNIX domain sockets and file descriptor passing mechanisms are not supported, Overseer uses Windows named pipes to implement communication between the main process and child processes. + +### Architecture Design + +1. **Main Process**: + - Creates a unique named pipe for each listening address + - Passes all pipe names as environment variables to the child process + - Accepts pipe connections from the child process, and creates dedicated data transfer channels for each address + +2. **Child Process**: + - Gets all pipe names and counts from environment variables + - Connects to corresponding dedicated pipes for each address + - Creates virtual listeners for each address, handling connections transparently + +3. **Connection Flow**: + - Client connects to the main process's TCP listener + - Main process directly forwards TCP connections to the dedicated pipe for the corresponding address + - Child process receives data from the pipe and processes the request + - No control information exchange is needed, data transmission is direct + ### More documentation * [Core `overseer` package](https://godoc.org/github.com/jpillora/overseer) diff --git a/example/go.mod b/example/go.mod index 8b4d90f..bb1faa0 100644 --- a/example/go.mod +++ b/example/go.mod @@ -1,6 +1,8 @@ module eg -go 1.13 +go 1.21 + +toolchain go1.23.2 replace github.com/jpillora/overseer => ../ diff --git a/example/go.sum b/example/go.sum index ad5212d..3118e9f 100644 --- a/example/go.sum +++ b/example/go.sum @@ -1,21 +1,72 @@ +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= +github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/aws/aws-sdk-go v1.29.34 h1:yrzwfDaZFe9oT4AmQeNNunSQA7c0m2chz0B43+bJ1ok= github.com/aws/aws-sdk-go v1.29.34/go.mod h1:1KvfttTE3SPKMpo8g2c6jL3ZKfXtFvKscTgahTma5Xg= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= +github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jpillora/s3 v1.1.4/go.mod h1:yedE603V+crlFi1Kl/5vZJaBu9pUzE9wvKegU/lF2zs= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/shirou/gopsutil v2.20.2+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/smartystreets/assertions v1.0.1/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM= +github.com/smartystreets/gunit v1.1.3/go.mod h1:EH5qMBab2UclzXUcpR8b93eHsIlp9u+pDQIRp5DZNzQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go.mod b/go.mod index 3dc7c72..69ea673 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,16 @@ module github.com/jpillora/overseer -go 1.13 +go 1.23.0 + +toolchain go1.23.2 require ( - github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d - github.com/go-ole/go-ole v1.2.4 // indirect + github.com/Microsoft/go-winio v0.6.2 github.com/jpillora/s3 v1.1.4 + github.com/yusufpapurcu/wmi v1.2.4 +) + +require ( + github.com/go-ole/go-ole v1.2.6 // indirect + golang.org/x/sys v0.10.0 // indirect ) diff --git a/go.sum b/go.sum index c440b5b..343a767 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,15 @@ -github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d h1:G0m3OIz70MZUWq3EgK3CesDbo8upS2Vm9/P3FtgI+Jk= -github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= -github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= -github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/jpillora/s3 v1.1.4 h1:YCCKDWzb/Ye9EBNd83ATRF/8wPEy0xd43Rezb6u6fzc= github.com/jpillora/s3 v1.1.4/go.mod h1:yedE603V+crlFi1Kl/5vZJaBu9pUzE9wvKegU/lF2zs= github.com/smartystreets/assertions v1.0.1 h1:voD4ITNjPL5jjBfgR/r8fPIIBrliWrWHeiJApdr3r4w= github.com/smartystreets/assertions v1.0.1/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM= github.com/smartystreets/gunit v1.1.3 h1:32x+htJCu3aMswhPw3teoJ+PnWPONqdNgaGs6Qt8ZaU= github.com/smartystreets/gunit v1.1.3/go.mod h1:EH5qMBab2UclzXUcpR8b93eHsIlp9u+pDQIRp5DZNzQ= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/graceful.go b/graceful.go index dfd65b0..160ae53 100644 --- a/graceful.go +++ b/graceful.go @@ -1,8 +1,7 @@ package overseer -//overseer listeners and connections allow graceful -//restarts by tracking when all connections from a listener -//have been closed +// overseer listeners and connections allow graceful restarts by tracking when all connections +// from a listener have been closed import ( "net" @@ -18,7 +17,7 @@ func newOverseerListener(l net.Listener) *overseerListener { } } -//gracefully closing net.Listener +// gracefully closing net.Listener type overseerListener struct { net.Listener closeError error @@ -27,12 +26,26 @@ type overseerListener struct { } func (l *overseerListener) Accept() (net.Conn, error) { - conn, err := l.Listener.(*net.TCPListener).AcceptTCP() - if err != nil { - return nil, err + var conn net.Conn + + // Try to convert the listener to TCPListener for better connection control + if tcpL, ok := l.Listener.(*net.TCPListener); ok { + tcpConn, tcpErr := tcpL.AcceptTCP() + if tcpErr != nil { + return nil, tcpErr + } + tcpConn.SetKeepAlive(true) // see http.tcpKeepAliveListener + tcpConn.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener + conn = tcpConn + } else { + // For non-TCP listeners, use standard Accept + standardConn, standardErr := l.Listener.Accept() + if standardErr != nil { + return nil, standardErr + } + conn = standardConn } - conn.SetKeepAlive(true) // see http.tcpKeepAliveListener - conn.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener + uconn := overseerConn{ Conn: conn, wg: &l.wg, @@ -44,18 +57,18 @@ func (l *overseerListener) Accept() (net.Conn, error) { case <-l.closeByForce: uconn.Close() case <-uconn.closed: - //closed manually + // closed manually } }() l.wg.Add(1) return uconn, nil } -//non-blocking trigger close +// non-blocking trigger close func (l *overseerListener) release(timeout time.Duration) { - //stop accepting connections - release fd + // stop accepting connections - release fd l.closeError = l.Listener.Close() - //start timer, close by force if deadline not met + // start timer, close by force if deadline not met waited := make(chan bool) go func() { l.wg.Wait() @@ -71,7 +84,7 @@ func (l *overseerListener) release(timeout time.Duration) { }() } -//blocking wait for close +// Close after blocking wait func (l *overseerListener) Close() error { l.wg.Wait() return l.closeError @@ -79,12 +92,17 @@ func (l *overseerListener) Close() error { func (l *overseerListener) File() *os.File { // returns a dup(2) - FD_CLOEXEC flag *not* set - tl := l.Listener.(*net.TCPListener) - fl, _ := tl.File() - return fl + if tcpL, ok := l.Listener.(*net.TCPListener); ok { + fl, _ := tcpL.File() + return fl + } + + // For non-TCP listeners, return nil + // This is safe because Windows version doesn't use ExtraFiles feature + return nil } -//notifying on close net.Conn +// notifying on close net.Conn type overseerConn struct { net.Conn wg *sync.WaitGroup diff --git a/overseer.go b/overseer.go index 8eb400c..8be17df 100644 --- a/overseer.go +++ b/overseer.go @@ -1,5 +1,4 @@ -// Package overseer implements daemonizable -// self-upgrading binaries in Go (golang). +// Package overseer implements daemonizable self-upgrading binaries in Go (golang). package overseer import ( @@ -25,44 +24,52 @@ const ( // Config defines overseer's run-time configuration type Config struct { - //Required will prevent overseer from fallback to running - //running the program in the main process on failure. + // Required will prevent overseer from fallback to running the program in the main process on failure. Required bool - //Program's main function + + // Program's main function Program func(state State) - //Program's zero-downtime socket listening address (set this or Addresses) + + // Address for program's zero-downtime socket listening (set this or Addresses) Address string - //Program's zero-downtime socket listening addresses (set this or Address) + + // Addresses ofr program's zero-downtime socket listening (set this or Address) Addresses []string - //RestartSignal will manually trigger a graceful restart. Defaults to SIGUSR2. + + // RestartSignal will manually trigger a graceful restart. Defaults to SIGUSR2. RestartSignal os.Signal - //TerminateTimeout controls how long overseer should - //wait for the program to terminate itself. After this - //timeout, overseer will issue a SIGKILL. + + // TerminateTimeout controls how long overseer should wait for the program to terminate itself. + // After this timeout, overseer will issue a SIGKILL. TerminateTimeout time.Duration - //MinFetchInterval defines the smallest duration between Fetch()s. - //This helps to prevent unwieldy fetch.Interfaces from hogging - //too many resources. Defaults to 1 second. + + // MinFetchInterval defines the smallest duration between Fetch()s. This helps to prevent unwieldy fetch. + // Interfaces from hogging too many resources. Defaults to 1 second. MinFetchInterval time.Duration - //PreUpgrade runs after a binary has been retrieved, user defined checks - //can be run here and returning an error will cancel the upgrade. + + // PreUpgrade runs after a binary has been retrieved, user defined checks can be run here and returning + // an error will cancel the upgrade. PreUpgrade func(tempBinaryPath string) error - //Debug enables all [overseer] logs. + + // Debug enables all [overseer] logs. Debug bool - //NoWarn disables warning [overseer] logs. + + // NoWarn disables warning [overseer] logs. NoWarn bool - //NoRestart disables all restarts, this option essentially converts - //the RestartSignal into a "ShutdownSignal". + + // NoRestart disables all restarts, this option essentially converts the RestartSignal into a "ShutdownSignal". NoRestart bool - //NoRestartAfterFetch disables automatic restarts after each upgrade. - //Though manual restarts using the RestartSignal can still be performed. + + // NoRestartAfterFetch disables automatic restarts after each upgrade. + // Though manual restarts using the RestartSignal can still be performed. NoRestartAfterFetch bool - //Fetcher will be used to fetch binaries. + + // Fetcher will be used to fetch binaries. Fetcher fetcher.Interface } func validate(c *Config) error { - //validate + // validate if c.Program == nil { return errors.New("overseer.Config.Program required") } @@ -86,15 +93,13 @@ func validate(c *Config) error { return nil } -//RunErr allows manual handling of any -//overseer errors. +// RunErr allows manual handling of any overseer errors. func RunErr(c Config) error { return runErr(&c) } -//Run executes overseer, if an error is -//encountered, overseer fallsback to running -//the program directly (unless Required is set). +// Run executes overseer, if an error is encountered, overseer falls-back to running +// the program directly (unless Config.Required is set). func Run(c Config) { err := runErr(&c) if err != nil { @@ -109,14 +114,14 @@ func Run(c Config) { os.Exit(0) } -//sanityCheck returns true if a check was performed +// sanityCheck returns true if a check was performed func sanityCheck() bool { - //sanity check + // sanity check if token := os.Getenv(envBinCheck); token != "" { fmt.Fprint(os.Stdout, token) return true } - //legacy sanity check using old env var + // legacy sanity check using old env var if token := os.Getenv(envBinCheckLegacy); token != "" { fmt.Fprint(os.Stdout, token) return true @@ -124,26 +129,24 @@ func sanityCheck() bool { return false } -//SanityCheck manually runs the check to ensure this binary -//is compatible with overseer. This tries to ensure that a restart -//is never performed against a bad binary, as it would require -//manual intervention to rectify. This is automatically done -//on overseer.Run() though it can be manually run prior whenever -//necessary. +// SanityCheck manually runs the check to ensure this binary is compatible with overseer. +// This tries to ensure that a restart is never performed against a bad binary, +// as it would require manual intervention to rectify. This is automatically done +// on overseer.Run() though it can be manually run prior whenever necessary. func SanityCheck() { if sanityCheck() { os.Exit(0) } } -//abstraction over master/slave +// abstraction over master/slave var currentProcess interface { triggerRestart() run() error } func runErr(c *Config) error { - //os not supported + // os not supported if !supported { return fmt.Errorf("os (%s) not supported", runtime.GOOS) } @@ -153,7 +156,7 @@ func runErr(c *Config) error { if sanityCheck() { return nil } - //run either in master or slave mode + // run either in master or slave mode if os.Getenv(envIsSlave) == "1" { currentProcess = &slave{Config: c} } else { @@ -162,15 +165,15 @@ func runErr(c *Config) error { return currentProcess.run() } -//Restart programmatically triggers a graceful restart. If NoRestart -//is enabled, then this will essentially be a graceful shutdown. +// Restart programmatically triggers a graceful restart. If Config.NoRestart is enabled, +// then this will essentially be a graceful shutdown. func Restart() { if currentProcess != nil { currentProcess.triggerRestart() } } -//IsSupported returns whether overseer is supported on the current OS. +// IsSupported returns whether overseer is supported on the current OS. func IsSupported() bool { return supported } diff --git a/proc_master.go b/proc_master.go index 2fd1451..63dd9d3 100644 --- a/proc_master.go +++ b/proc_master.go @@ -5,16 +5,15 @@ import ( "crypto/rand" "crypto/sha1" "encoding/hex" + "errors" "fmt" "io" "log" - "net" "os" "os/exec" "os/signal" "path/filepath" "runtime" - "strconv" "sync" "syscall" "time" @@ -22,7 +21,7 @@ import ( var tmpBinPath = filepath.Join(os.TempDir(), "overseer-"+token()+extension()) -//a overseer master process +// an overseer master process type master struct { *Config slaveID int @@ -39,6 +38,8 @@ type master struct { descriptorsReleased chan bool signalledAt time.Time printCheckUpdate bool + pipeNames []string // Windows pipe names array (one pipe per address) + pipeCancel func() // Function to cancel named pipe listening } func (mp *master) run() error { @@ -53,7 +54,7 @@ func (mp *master) run() error { } } mp.setupSignalling() - if err := mp.retreiveFileDescriptors(); err != nil { + if err := mp.retrieveFileDescriptors(); err != nil { return err } if mp.Config.Fetcher != nil { @@ -61,11 +62,19 @@ func (mp *master) run() error { mp.fetch() go mp.fetchLoop() } + + // Close named pipe when process exits + defer func() { + if mp.pipeCancel != nil { + mp.pipeCancel() + } + }() + return mp.forkLoop() } func (mp *master) checkBinary() error { - //get path to binary and confirm its writable + // get path to binary and confirm its writable binPath, err := os.Executable() if err != nil { return fmt.Errorf("failed to find binary path (%s)", err) @@ -76,19 +85,19 @@ func (mp *master) checkBinary() error { } else if info.Size() == 0 { return fmt.Errorf("binary file is empty") } else { - //copy permissions + // copy permissions mp.binPerms = info.Mode() } f, err := os.Open(binPath) if err != nil { return fmt.Errorf("cannot read binary (%s)", err) } - //initial hash of file + // initial hash of file hash := sha1.New() io.Copy(hash, f) mp.binHash = hash.Sum(nil) f.Close() - //test bin<->tmpbin moves + // test bin<->tmpbin moves if mp.Config.Fetcher != nil { if err := move(tmpBinPath, mp.binPath); err != nil { return fmt.Errorf("cannot move binary (%s)", err) @@ -101,10 +110,10 @@ func (mp *master) checkBinary() error { } func (mp *master) setupSignalling() { - //updater-forker comms + // updater-forker commands mp.restarted = make(chan bool) mp.descriptorsReleased = make(chan bool) - //read all master process signals + // read all master process signals signals := make(chan os.Signal) signal.Notify(signals) go func() { @@ -116,26 +125,23 @@ func (mp *master) setupSignalling() { func (mp *master) handleSignal(s os.Signal) { if s == mp.RestartSignal { - //user initiated manual restart + // user initiated manual restart go mp.triggerRestart() } else if s.String() == "child exited" { // will occur on every restart, ignore it } else - //**during a restart** a SIGUSR1 signals - //to the master process that, the file - //descriptors have been released + // **during a restart** a SIGUSR1 signals to the master process that, the file descriptors have been released if mp.awaitingUSR1 && s == SIGUSR1 { mp.debugf("signaled, sockets ready") mp.awaitingUSR1 = false mp.descriptorsReleased <- true } else - //while the slave process is running, proxy - //all signals through + // while the slave process is running, proxy all signals through if mp.slaveCmd != nil && mp.slaveCmd.Process != nil { mp.debugf("proxy signal (%s)", s) mp.sendSignal(s) } else - //otherwise if not running, kill on CTRL+c + // otherwise if not running, kill on CTRL+c if s == os.Interrupt { mp.debugf("interupt with no slave") os.Exit(1) @@ -153,30 +159,7 @@ func (mp *master) sendSignal(s os.Signal) { } } -func (mp *master) retreiveFileDescriptors() error { - mp.slaveExtraFiles = make([]*os.File, len(mp.Config.Addresses)) - for i, addr := range mp.Config.Addresses { - a, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return fmt.Errorf("Invalid address %s (%s)", addr, err) - } - l, err := net.ListenTCP("tcp", a) - if err != nil { - return err - } - f, err := l.File() - if err != nil { - return fmt.Errorf("Failed to retreive fd for: %s (%s)", addr, err) - } - if err := l.Close(); err != nil { - return fmt.Errorf("Failed to close listener for: %s (%s)", addr, err) - } - mp.slaveExtraFiles[i] = f - } - return nil -} - -//fetchLoop is run in a goroutine +// fetchLoop is run in a goroutine func (mp *master) fetchLoop() { min := mp.Config.MinFetchInterval time.Sleep(min) @@ -310,30 +293,30 @@ func (mp *master) fetch() { func (mp *master) triggerRestart() { if mp.restarting { mp.debugf("already graceful restarting") - return //skip + return // skip } else if mp.slaveCmd == nil || mp.restarting { mp.debugf("no slave process") - return //skip + return // skip } mp.debugf("graceful restart triggered") mp.restarting = true mp.awaitingUSR1 = true mp.signalledAt = time.Now() - mp.sendSignal(mp.Config.RestartSignal) //ask nicely to terminate + mp.sendSignal(mp.Config.RestartSignal) // ask nicely to terminate select { case <-mp.restarted: - //success + // success mp.debugf("restart success") case <-time.After(mp.TerminateTimeout): - //times up mr. process, we did ask nicely! + // times up mr. process, we did ask nicely! mp.debugf("graceful timeout, forcing exit") mp.sendSignal(os.Kill) } } -//not a real fork +// not a real fork func (mp *master) forkLoop() error { - //loop, restart command + // loop, restart command for { if err := mp.fork(); err != nil { return err @@ -344,68 +327,74 @@ func (mp *master) forkLoop() error { func (mp *master) fork() error { mp.debugf("starting %s", mp.binPath) cmd := exec.Command(mp.binPath) - //mark this new process as the "active" slave process. - //this process is assumed to be holding the socket files. + + // mark this new process as the "active" slave process. + // this process is assumed to be holding the socket files. mp.slaveCmd = cmd mp.slaveID++ - //provide the slave process with some state - e := os.Environ() - e = append(e, envBinID+"="+hex.EncodeToString(mp.binHash)) - e = append(e, envBinPath+"="+mp.binPath) - e = append(e, envSlaveID+"="+strconv.Itoa(mp.slaveID)) - e = append(e, envIsSlave+"=1") - e = append(e, envNumFDs+"="+strconv.Itoa(len(mp.slaveExtraFiles))) - cmd.Env = e - //inherit master args/stdfiles + + // provide the slave process with some state + cmd.Env = mp.retrieveSlaveEnviron() + + // inherit master args and std files cmd.Args = os.Args cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - //include socket files - cmd.ExtraFiles = mp.slaveExtraFiles + + // include socket files + // Windows doesn't support ExtraFiles, Windows version uses named pipes instead + if runtime.GOOS != "windows" { + cmd.ExtraFiles = mp.slaveExtraFiles + } + if err := cmd.Start(); err != nil { - return fmt.Errorf("Failed to start slave process: %s", err) + return fmt.Errorf("failed to start slave process: %s", err) } - //was scheduled to restart, notify success + // was scheduled to restart, notify success if mp.restarting { mp.restartedAt = time.Now() mp.restarting = false mp.restarted <- true } - //convert wait into channel - cmdwait := make(chan error) + + // convert wait into channel + cmdWait := make(chan error) go func() { - cmdwait <- cmd.Wait() + cmdWait <- cmd.Wait() }() - //wait.... + + // wait.... select { - case err := <-cmdwait: - //program exited before releasing descriptors - //proxy exit code out to master + case err := <-cmdWait: + // program exited before releasing descriptors proxy exit code out to master code := 0 if err != nil { code = 1 - if exiterr, ok := err.(*exec.ExitError); ok { - if status, ok := exiterr.Sys().(syscall.WaitStatus); ok { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { code = status.ExitStatus() + mp.debugf("prog exited with status: %d, error: %v", code, exitErr) + } else { + mp.debugf("prog exited with status: %d, but could not get syscall.WaitStatus", code) } + } else { + mp.debugf("prog exited with error (not ExitError): %v", err) } } mp.debugf("prog exited with %d", code) - //if a restarts are disabled or if it was an - //unexpected crash, proxy this exit straight - //through to the main process + // if a restarts are disabled or if it was an unexpected crash, + // proxy this exit straight through to the main process if mp.NoRestart || !mp.restarting { + mp.debugf("shutdown: NoRestart=%v, restarting=%v, proceeding with exit(%d)", mp.NoRestart, mp.restarting, code) os.Exit(code) } case <-mp.descriptorsReleased: - //if descriptors are released, the program - //has yielded control of its sockets and - //a parallel instance of the program can be - //started safely. it should serve state.Listeners - //to ensure downtime is kept at <1sec. The previous - //cmd.Wait() will still be consumed though the - //result will be discarded. + // if descriptors are released, the program has yielded control of its sockets and a parallel instance of + // the program can be started safely. it should serve state.Listeners to ensure downtime is kept at <1sec. + // The previous cmd.Wait() will still be consumed though the result will be discarded. + mp.debugf("descriptors released, starting new process") } return nil } diff --git a/proc_master_others.go b/proc_master_others.go new file mode 100644 index 0000000..a3bc6b9 --- /dev/null +++ b/proc_master_others.go @@ -0,0 +1,45 @@ +//go:build !windows + +package overseer + +import ( + "encoding/hex" + "fmt" + "net" + "os" + "strconv" +) + +func (mp *master) retrieveFileDescriptors() error { + mp.slaveExtraFiles = make([]*os.File, len(mp.Config.Addresses)) + for i, addr := range mp.Config.Addresses { + a, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return fmt.Errorf("invalid address %s (%s)", addr, err) + } + l, err := net.ListenTCP("tcp", a) + if err != nil { + return err + } + f, err := l.File() + if err != nil { + return fmt.Errorf("failed to retrieve fd for: %s (%s)", addr, err) + } + if err := l.Close(); err != nil { + return fmt.Errorf("failed to close listener for: %s (%s)", addr, err) + } + mp.slaveExtraFiles[i] = f + } + return nil +} + +// provide the slave process with some state +func (mp *master) retrieveSlaveEnviron() []string { + e := os.Environ() + e = append(e, envBinID+"="+hex.EncodeToString(mp.binHash)) + e = append(e, envBinPath+"="+mp.binPath) + e = append(e, envSlaveID+"="+strconv.Itoa(mp.slaveID)) + e = append(e, envIsSlave+"=1") + e = append(e, envNumFDs+"="+strconv.Itoa(len(mp.slaveExtraFiles))) + return e +} diff --git a/proc_master_windows.go b/proc_master_windows.go new file mode 100644 index 0000000..ceef63e --- /dev/null +++ b/proc_master_windows.go @@ -0,0 +1,322 @@ +//go:build windows + +package overseer + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "os" + "os/signal" + "strconv" + "strings" + "sync" + "syscall" + "time" +) + +// forwarder is used to forward an incoming net.Conn to another actual handler. +// Either local (in parent case) or via a Dialer (in child case). +type forwarder struct { + handle func(net.Conn) + close func() + wg sync.WaitGroup +} + +func (f *forwarder) closeAndWait() { + f.close() + f.wg.Wait() +} + +type wgCloser struct { + net.Conn + done func() +} + +func (wc wgCloser) Close() error { + err := wc.Conn.Close() + if err == nil || !errors.Is(err, net.ErrClosed) { + wc.done() + } + return err +} + +type childRequest struct { + Addresses []string `json:"addresses"` +} + +func (mp *master) retrieveFileDescriptors() (err error) { + mp.slaveExtraFiles = make([]*os.File, len(mp.Config.Addresses)) + listeners := make([]net.Listener, 0, len(mp.Config.Addresses)) + + // Create unique named pipes for each address + pipeNames := make([]string, len(mp.Config.Addresses)) + pipeListeners := make([]net.Listener, len(mp.Config.Addresses)) + + // Track all active connections for cleanup + var activePipeConns sync.Map + + // Create close function to clean up all connections when process exits + closePipes := func() { + mp.debugf("Closing all named pipe connections") + activePipeConns.Range(func(key, value interface{}) bool { + if conn, ok := value.(net.Conn); ok { + conn.Close() + } + return true + }) + } + + // Save cancel function + originalCancel := mp.pipeCancel + mp.pipeCancel = func() { + if originalCancel != nil { + originalCancel() + } + closePipes() + } + + // Listen for system termination signals + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + mp.debugf("Received termination signal, cleaning up resources") + closePipes() + }() + + // Create all pipe listeners + for i, addr := range mp.Config.Addresses { + // Create unique pipe name for each address + pipeName := fmt.Sprintf("overseer_%d_%s_%d.pipe", os.Getpid(), + hex.EncodeToString([]byte(addr))[0:8], i) + + // Create named pipe listener + pipeListener, err := listenPipe(pipeName) + if err != nil { + // Close already created pipes + for j := 0; j < i; j++ { + pipeListeners[j].Close() + } + return fmt.Errorf("failed to create named pipe for address %s: %v", addr, err) + } + + pipeNames[i] = pipeName + pipeListeners[i] = pipeListener + mp.debugf("Created named pipe for address %s", addr) + } + + // Save pipe names to pass as environment variables to child process + mp.pipeNames = pipeNames + + // Create cancellable context + ctx, cancel := context.WithCancel(context.Background()) + mp.pipeCancel = cancel + + // Create TCP listeners and corresponding forwarders for each address + for i, addr := range mp.Config.Addresses { + // Create TCP listener + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen on address %s: %v", addr, err) + } + listeners = append(listeners, ln) + mp.debugf("Created TCP listener for address %s", addr) + + // Create connection channel + connCh := make(chan net.Conn, 10) + forwarderCh := make(chan *forwarder, 1) + + // Start external listener + go mp.runExternalListener(ln, connCh) + + // Start internal listener + go mp.runInternalListener(connCh, forwarderCh, func() { + mp.debugf("Internal listener for address %s closed", addr) + }) + + // Wait in background for child process to connect to the corresponding pipe + go func(i int, addr string, pipeListener net.Listener, forwarderCh chan *forwarder) { + defer pipeListener.Close() + mp.debugf("Waiting for child process to connect to named pipe for address %s", addr) + + // Listen for context cancellation + go func() { + <-ctx.Done() + pipeListener.Close() + }() + + // Accept connection + conn, err := pipeListener.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + mp.debugf("Error accepting pipe connection for address %s: %v", addr, err) + return + } + + // Store connection for cleanup + connID := fmt.Sprintf("pipe-%s-%d", addr, i) + activePipeConns.Store(connID, conn) + + mp.debugf("Accepted pipe connection from child process for address %s", addr) + + // Create forwarder + fw := &forwarder{ + close: func() { + conn.Close() + activePipeConns.Delete(connID) + }, + } + + // Set handler function + fw.handle = func(clientConn net.Conn) { + // Add to wait group + fw.wg.Add(1) + wc := wgCloser{ + Conn: clientConn, + done: fw.wg.Done, + } + + // Start data forwarding - directly use pipe connection + go proxyConnection(conn, wc, mp.debugf) + } + + // Send forwarder to internal listener + forwarderCh <- fw + + // Keep connection open until main process closes + <-ctx.Done() + conn.Close() + activePipeConns.Delete(connID) + }(i, addr, pipeListeners[i], forwarderCh) + } + + return nil +} + +// proxyConnection transfers data bidirectionally between two connections +func proxyConnection(conn1, conn2 net.Conn, debug func(string, ...interface{})) { + // Create control channel + done := make(chan struct{}) + + // From conn1 to conn2 + go func() { + _, err := io.Copy(conn2, conn1) + if err != nil { + if err != io.EOF && !isConnectionClosed(err) { + debug("Data transfer error: %v", err) + } + } + conn2.Close() + close(done) + }() + + // From conn2 to conn1 + go func() { + _, err := io.Copy(conn1, conn2) + if err != nil { + if err != io.EOF && !isConnectionClosed(err) { + debug("Data transfer error: %v", err) + } + } + conn1.Close() + }() +} + +// isConnectionClosed determines if an error is a connection closed error +func isConnectionClosed(err error) bool { + if err == nil { + return false + } + if errors.Is(err, net.ErrClosed) { + return true + } + if errors.Is(err, io.EOF) { + return true + } + if strings.Contains(err.Error(), "use of closed network connection") { + return true + } + if strings.Contains(err.Error(), "connection reset by peer") { + return true + } + if strings.Contains(err.Error(), "broken pipe") { + return true + } + return false +} + +func (mp *master) runExternalListener(ln net.Listener, ch chan net.Conn) { + defer close(ch) + for { + rw, err := ln.Accept() + + if err != nil { + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + time.Sleep(5 * time.Millisecond) + continue + } + return + } + ch <- rw + } +} + +func (mp *master) runInternalListener(connCh chan net.Conn, forwarderCh chan *forwarder, done func()) { + defer done() + + var current *forwarder + defer func() { + if current != nil { + current.closeAndWait() + } + }() + + for { + select { + case conn, ok := <-connCh: + if !ok { + // connCh closed, we're shutting down + return + } + if current != nil { + current.handle(conn) + } else { + // No child process, close connection directly + conn.Close() + } + case fw, ok := <-forwarderCh: + if !ok { + // forwarderCh closed, we're shutting down + return + } + if current != nil { + current.closeAndWait() + } + current = fw + } + } +} + +// provide the slave process with some state +func (mp *master) retrieveSlaveEnviron() []string { + e := os.Environ() + e = append(e, envBinID+"="+hex.EncodeToString(mp.binHash)) + e = append(e, envBinPath+"="+mp.binPath) + e = append(e, envSlaveID+"="+strconv.Itoa(mp.slaveID)) + e = append(e, envIsSlave+"=1") + + // Add all pipe names to environment variables + for i, pipeName := range mp.pipeNames { + e = append(e, fmt.Sprintf("OVERSEER_PIPE_NAME_%d=%s", i, pipeName)) + } + e = append(e, fmt.Sprintf("OVERSEER_PIPE_COUNT=%d", len(mp.pipeNames))) + + return e +} diff --git a/proc_slave.go b/proc_slave.go index 4f64e14..311324e 100644 --- a/proc_slave.go +++ b/proc_slave.go @@ -1,46 +1,48 @@ package overseer import ( - "fmt" "log" "net" "os" "os/signal" - "strconv" "time" ) var ( - //DisabledState is a placeholder state for when - //overseer is disabled and the program function - //is run manually. + // DisabledState is a placeholder state for when overseer is disabled and the program function is run manually. DisabledState = State{Enabled: false} ) // State contains the current run-time state of overseer type State struct { - //whether overseer is running enabled. When enabled, + // whether overseer is running enabled. When enabled, //this program will be running in a child process and //overseer will perform rolling upgrades. Enabled bool - //ID is a SHA-1 hash of the current running binary + + // ID is an SHA-1 hash of the current running binary ID string - //StartedAt records the start time of the program + + // StartedAt records the start time of the program StartedAt time.Time - //Listener is the first net.Listener in Listeners + + // Listener is the first net.Listener in Listeners Listener net.Listener - //Listeners are the set of acquired sockets by the master - //process. These are all passed into this program in the - //same order they are specified in Config.Addresses. + + // Listeners are the set of acquired sockets by the master process. + // These are all passed into this program in the same order they are specified in Config.Addresses. Listeners []net.Listener - //Program's first listening address + + // Address for program's first listening Address string - //Program's listening addresses + + // Addresses for program's listening Addresses []string - //GracefulShutdown will be filled when its time to perform - //a graceful shutdown. + + // GracefulShutdown will be filled when it's time to perform a graceful shutdown. GracefulShutdown chan bool - //Path of the binary currently being executed + + // BinPath is path of the binary currently being executed BinPath string } @@ -78,30 +80,6 @@ func (sp *slave) run() error { return nil } -func (sp *slave) initFileDescriptors() error { - //inspect file descriptors - numFDs, err := strconv.Atoi(os.Getenv(envNumFDs)) - if err != nil { - return fmt.Errorf("invalid %s integer", envNumFDs) - } - sp.listeners = make([]*overseerListener, numFDs) - sp.state.Listeners = make([]net.Listener, numFDs) - for i := 0; i < numFDs; i++ { - f := os.NewFile(uintptr(3+i), "") - l, err := net.FileListener(f) - if err != nil { - return fmt.Errorf("failed to inherit file descriptor: %d", i) - } - u := newOverseerListener(l) - sp.listeners[i] = u - sp.state.Listeners[i] = u - } - if len(sp.state.Listeners) > 0 { - sp.state.Listener = sp.state.Listeners[0] - } - return nil -} - func (sp *slave) watchSignal() { signals := make(chan os.Signal) signal.Notify(signals, sp.Config.RestartSignal) @@ -109,23 +87,22 @@ func (sp *slave) watchSignal() { <-signals signal.Stop(signals) sp.debugf("graceful shutdown requested") - //master wants to restart, + // master wants to restart, close(sp.state.GracefulShutdown) - //release any sockets and notify master + // release any sockets and notify master if len(sp.listeners) > 0 { - //perform graceful shutdown + // perform graceful shutdown for _, l := range sp.listeners { l.release(sp.Config.TerminateTimeout) } - //signal release of held sockets, allows master to start - //a new process before this child has actually exited. - //early restarts not supported with restarts disabled. + // signal release of held sockets, allows master to start a new process + // before this child has actually exited. early restarts not supported with restarts disabled. if !sp.NoRestart { sp.masterProc.Signal(SIGUSR1) } - //listeners should be waiting on connections to close... + // listeners should be waiting on connections to close... } - //start death-timer + // start death-timer go func() { time.Sleep(sp.Config.TerminateTimeout) sp.debugf("timeout. forceful shutdown") diff --git a/proc_slave_others.go b/proc_slave_others.go index 2a028e1..a981f7e 100644 --- a/proc_slave_others.go +++ b/proc_slave_others.go @@ -1,10 +1,12 @@ -// +build !windows +//go:build !windows package overseer import ( "fmt" + "net" "os" + "strconv" "syscall" "time" ) @@ -17,9 +19,9 @@ func (sp *slave) watchParent() error { } sp.masterProc = proc go func() { - //send signal 0 to master process forever + // send signal 0 to master process forever for { - //should not error as long as the process is alive + // should not error as long as the process is alive if err := sp.masterProc.Signal(syscall.Signal(0)); err != nil { os.Exit(1) } @@ -32,3 +34,27 @@ func (sp *slave) watchParent() error { func overwrite(dst, src string) error { return move(dst, src) } + +func (sp *slave) initFileDescriptors() error { + // inspect file descriptors + numFDs, err := strconv.Atoi(os.Getenv(envNumFDs)) + if err != nil { + return fmt.Errorf("invalid %s integer", envNumFDs) + } + sp.listeners = make([]*overseerListener, numFDs) + sp.state.Listeners = make([]net.Listener, numFDs) + for i := 0; i < numFDs; i++ { + f := os.NewFile(uintptr(3+i), "") + l, err := net.FileListener(f) + if err != nil { + return fmt.Errorf("failed to inherit file descriptor: %d", i) + } + u := newOverseerListener(l) + sp.listeners[i] = u + sp.state.Listeners[i] = u + } + if len(sp.state.Listeners) > 0 { + sp.state.Listener = sp.state.Listeners[0] + } + return nil +} diff --git a/proc_slave_windows.go b/proc_slave_windows.go index 1ad5b9a..08b5234 100644 --- a/proc_slave_windows.go +++ b/proc_slave_windows.go @@ -1,22 +1,26 @@ -// +build windows +//go:build windows package overseer import ( "context" "fmt" + "io" + "net" "os" + "strconv" "strings" + "sync" "time" - "github.com/StackExchange/wmi" + "github.com/yusufpapurcu/wmi" ) var ( Timeout = 3 * time.Second ) -type Win32_Process struct { +type Win32Process struct { Name string ExecutablePath *string CommandLine *string @@ -55,6 +59,268 @@ type Win32_Process struct { WorkingSetSize uint64 } +// Implement connection to parent process via named pipes +func (sp *slave) initFileDescriptors() error { + // Get the number of pipes created by parent process + pipeCountStr := os.Getenv("OVERSEER_PIPE_COUNT") + if pipeCountStr == "" { + return fmt.Errorf("missing pipe count environment variable") + } + + pipeCount, err := strconv.Atoi(pipeCountStr) + if err != nil { + return fmt.Errorf("invalid pipe count: %v", err) + } + + // Verify that pipe count matches address count + if pipeCount != len(sp.Config.Addresses) { + return fmt.Errorf("pipe count (%d) does not match address count (%d)", + pipeCount, len(sp.Config.Addresses)) + } + + // Create virtual listeners + sp.listeners = make([]*overseerListener, pipeCount) + sp.state.Listeners = make([]net.Listener, pipeCount) + sp.debugf("Preparing to create %d virtual listeners", pipeCount) + + // Connect to corresponding pipes and create listeners for each address + for i, addr := range sp.Config.Addresses { + // Get pipe name for this address + pipeName := os.Getenv(fmt.Sprintf("OVERSEER_PIPE_NAME_%d", i)) + if pipeName == "" { + return fmt.Errorf("missing pipe name for address %s", addr) + } + + sp.debugf("Trying to connect to pipe for address %s", addr) + + // Connect to pipe + conn, err := dialPipe(pipeName) + if err != nil { + return fmt.Errorf("failed to connect to pipe for address %s: %v", addr, err) + } + + // Create connection channel for this address + connCh := make(chan net.Conn, 10) + + // Create virtual listener + l := &pipeListener{ + addr: addr, + connCh: connCh, + } + + // Create overseerListener wrapper + u := newOverseerListener(l) + sp.listeners[i] = u + sp.state.Listeners[i] = u + + // Start goroutine to handle TCP connections + go func(addr string, conn net.Conn, connCh chan net.Conn) { + defer conn.Close() + + // Create connection termination channel + terminated := make(chan struct{}) + + // Goroutine to detect pipe connection status + go func() { + // Try simple read, EOF or error indicates the pipe has disconnected + buf := make([]byte, 1) + _, err := conn.Read(buf) + if err != nil { + if err == io.EOF { + sp.debugf("Detected pipe EOF, parent process may have exited: %s", addr) + } else { + sp.debugf("Detected pipe error, parent process may have exited: %v", err) + } + close(terminated) + } + }() + + // Connection loop + connectionAttempts := 0 + for { + select { + case <-terminated: + return + default: + // If too many consecutive failures, stop creating new connections + if connectionAttempts >= 3 { + sp.debugf("Too many consecutive connection failures for address %s, stopping attempts", addr) + return + } + + // Create pipe connection + pipeConn := &pipeConn{ + reader: conn, + writer: conn, + local: l.Addr(), + id: fmt.Sprintf("%s-conn", addr), + sp: sp, + done: make(chan struct{}), + } + + // Send to connection channel + select { + case connCh <- pipeConn: + // Wait for connection to complete or close + select { + case <-pipeConn.done: + // Check if closed due to EOF + if pipeConn.closedByEOF { + connectionAttempts++ + // Add delay after consecutive EOF errors to avoid generating excessive logs + time.Sleep(time.Duration(connectionAttempts*500) * time.Millisecond) + } else { + // Normal closure, reset attempt counter + connectionAttempts = 0 + } + case <-terminated: + return + } + case <-time.After(2 * time.Second): + // Timeout, pipe may be closed + connectionAttempts++ + return + case <-terminated: + return + } + } + } + }(addr, conn, connCh) + } + + // Set main listener + if len(sp.state.Listeners) > 0 { + sp.state.Listener = sp.state.Listeners[0] + sp.debugf("Setting first listener as main listener") + } + + return nil +} + +// pipeListener implements net.Listener interface +type pipeListener struct { + addr string + connCh chan net.Conn + closed bool + mu sync.Mutex +} + +func (l *pipeListener) Accept() (net.Conn, error) { + if conn, ok := <-l.connCh; ok { + return conn, nil + } + return nil, fmt.Errorf("listener closed") +} + +func (l *pipeListener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + + if !l.closed { + l.closed = true + close(l.connCh) + } + + return nil +} + +func (l *pipeListener) Addr() net.Addr { + return pipeAddr(l.addr) +} + +// pipeAddr implements net.Addr interface +type pipeAddr string + +func (a pipeAddr) Network() string { return "pipe" } +func (a pipeAddr) String() string { return string(a) } + +// pipeConn implements net.Conn interface +type pipeConn struct { + reader io.Reader + writer io.Writer + local net.Addr + closed bool + closedByEOF bool // Mark as closed by EOF + closeMu sync.Mutex + id string + sp *slave + done chan struct{} // Signal channel, indicating connection completed/closed +} + +// Initialize pipeConn +func newPipeConn(reader io.Reader, writer io.Writer, local net.Addr, id string, sp *slave) *pipeConn { + return &pipeConn{ + reader: reader, + writer: writer, + local: local, + id: id, + sp: sp, + done: make(chan struct{}), + } +} + +func (c *pipeConn) Read(b []byte) (n int, err error) { + if c.closed { + return 0, net.ErrClosed + } + + n, err = c.reader.Read(b) + if err != nil && c.sp != nil { + if err == io.EOF { + c.closedByEOF = true // Mark as closed by EOF + } + } + + return n, err +} + +func (c *pipeConn) Write(b []byte) (n int, err error) { + if c.closed { + return 0, net.ErrClosed + } + + n, err = c.writer.Write(b) + return n, err +} + +func (c *pipeConn) Close() error { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if !c.closed { + c.closed = true + + // Notify connection has completed + select { + case <-c.done: // Already closed + default: + close(c.done) + } + } + + return nil +} + +func (c *pipeConn) LocalAddr() net.Addr { + return c.local +} + +func (c *pipeConn) RemoteAddr() net.Addr { + return pipeAddr("pipe-remote") +} + +func (c *pipeConn) SetDeadline(t time.Time) error { + return nil // Not supported +} + +func (c *pipeConn) SetReadDeadline(t time.Time) error { + return nil // Not supported +} + +func (c *pipeConn) SetWriteDeadline(t time.Time) error { + return nil // Not supported +} + func (sp *slave) watchParent() error { sp.masterPid = os.Getppid() proc, err := os.FindProcess(sp.masterPid) @@ -62,34 +328,61 @@ func (sp *slave) watchParent() error { return fmt.Errorf("master process: %s", err) } sp.masterProc = proc + sp.debugf("Found parent process, PID: %d", sp.masterPid) + go func() { - //send signal 0 to master process forever + sp.debugf("Starting monitoring of parent process (PID: %d) alive status", sp.masterPid) + // Periodically check if parent process is alive + failures := 0 for { - //should not error as long as the process is alive - if _, err := GetWin32Proc(int32(sp.masterPid)); err != nil { - os.Exit(1) + // Try to get process info via WMI + _, err := GetWin32Proc(int32(sp.masterPid)) + if err != nil { + failures++ + sp.debugf("Parent process detection failed (%d/3): %v", failures, err) + + // If it's a WMI error, try alternative methods to check the process + if failures == 1 { + // Try using os.FindProcess as a fallback + if _, err := os.FindProcess(sp.masterPid); err == nil { + // On Windows, FindProcess almost always succeeds, + // so we try to send a null signal to check if process exists + if err := sp.masterProc.Signal(os.Signal(nil)); err == nil { + sp.debugf("Verified parent process alive using alternative method") + failures = 0 // Reset failure count + } + } + } + + if failures >= 3 { + sp.debugf("Parent process has terminated, child process will exit") + os.Exit(1) + } + } else if failures > 0 { + sp.debugf("Parent process detection returned to normal") + failures = 0 } time.Sleep(2 * time.Second) } }() + return nil } -func GetWin32Proc(pid int32) ([]Win32_Process, error) { +func GetWin32Proc(pid int32) ([]Win32Process, error) { return GetWin32ProcWithContext(context.Background(), pid) } -func GetWin32ProcWithContext(ctx context.Context, pid int32) ([]Win32_Process, error) { - var dst []Win32_Process - query := fmt.Sprintf("WHERE ProcessId = %d", pid) - q := wmi.CreateQuery(&dst, query) - err := WMIQueryWithContext(ctx, q, &dst) +func GetWin32ProcWithContext(ctx context.Context, pid int32) ([]Win32Process, error) { + var dst []Win32Process + query := fmt.Sprintf("SELECT * FROM Win32_Process WHERE ProcessId = %d", pid) + err := WMIQueryWithContext(ctx, query, &dst) if err != nil { - return []Win32_Process{}, fmt.Errorf("could not get win32Proc: %s", err) + return []Win32Process{}, fmt.Errorf("could not get win32Proc: %s", err) } if len(dst) == 0 { - return []Win32_Process{}, fmt.Errorf("could not get win32Proc: empty") + return []Win32Process{}, fmt.Errorf("could not get win32Proc: empty") } return dst, nil @@ -104,7 +397,7 @@ func WMIQueryWithContext(ctx context.Context, query string, dst interface{}, con errChan := make(chan error, 1) go func() { - errChan <- wmi.Query(query, dst, connectServerArgs...) + errChan <- wmi.QueryNamespace(query, dst, "root\\CIMV2") }() select { diff --git a/socket_windows.go b/socket_windows.go new file mode 100644 index 0000000..fb45ad5 --- /dev/null +++ b/socket_windows.go @@ -0,0 +1,19 @@ +//go:build windows + +package overseer + +import ( + "net" + + "github.com/Microsoft/go-winio" +) + +// listenPipe creates a Windows named pipe listener +func listenPipe(name string) (net.Listener, error) { + return winio.ListenPipe(`\\.\pipe\`+name, nil) +} + +// dialPipe connects to a Windows named pipe +func dialPipe(name string) (net.Conn, error) { + return winio.DialPipe(`\\.\pipe\`+name, nil) +} diff --git a/sys_posix.go b/sys_posix.go index c0abf71..2096ecb 100644 --- a/sys_posix.go +++ b/sys_posix.go @@ -1,4 +1,4 @@ -// +build linux darwin freebsd +//go:build linux || darwin || freebsd package overseer diff --git a/sys_unsupported.go b/sys_unsupported.go index 4a64861..96d69ba 100644 --- a/sys_unsupported.go +++ b/sys_unsupported.go @@ -1,4 +1,4 @@ -// +build !linux,!darwin,!windows,!freebsd +//go:build !linux && !darwin && !windows && !freebsd package overseer diff --git a/sys_windows.go b/sys_windows.go index 32b39a2..f481f8b 100644 --- a/sys_windows.go +++ b/sys_windows.go @@ -1,4 +1,4 @@ -// +build windows +//go:build windows package overseer