@@ -23,6 +23,7 @@ import (
23
23
"net"
24
24
"os"
25
25
"os/exec"
26
+ "os/signal"
26
27
"path/filepath"
27
28
"strconv"
28
29
"strings"
@@ -1199,3 +1200,106 @@ func createValidConfig(t *testing.T, socketPath string) Config {
1199
1200
},
1200
1201
}
1201
1202
}
1203
+
1204
+ func TestSignalForwarding (t * testing.T ) {
1205
+ forwardedSignals := []os.Signal {
1206
+ syscall .SIGUSR1 ,
1207
+ syscall .SIGUSR2 ,
1208
+ syscall .SIGINT ,
1209
+ syscall .SIGTERM ,
1210
+ }
1211
+ ignoredSignals := []os.Signal {
1212
+ syscall .SIGHUP ,
1213
+ syscall .SIGQUIT ,
1214
+ }
1215
+
1216
+ cfg := Config {
1217
+ Debug : true ,
1218
+ KernelImagePath : filepath .Join (testDataPath , "vmlinux" ),
1219
+ SocketPath : "/tmp/TestSignalForwarding.sock" ,
1220
+ Drives : []models.Drive {
1221
+ {
1222
+ DriveID : String ("0" ),
1223
+ IsRootDevice : Bool (true ),
1224
+ IsReadOnly : Bool (false ),
1225
+ PathOnHost : String (testRootfs ),
1226
+ },
1227
+ },
1228
+ DisableValidation : true ,
1229
+ ForwardSignals : forwardedSignals ,
1230
+ }
1231
+ defer os .RemoveAll ("/tmp/TestSignalForwarding.sock" )
1232
+
1233
+ opClient := fctesting.MockClient {}
1234
+
1235
+ ctx := context .Background ()
1236
+ client := NewClient (cfg .SocketPath , fctesting .NewLogEntry (t ), true , WithOpsClient (& opClient ))
1237
+
1238
+ fd , err := net .Listen ("unix" , cfg .SocketPath )
1239
+ if err != nil {
1240
+ t .Fatalf ("unexpected error during creation of unix socket: %v" , err )
1241
+ }
1242
+ defer fd .Close ()
1243
+
1244
+ stdout := & bytes.Buffer {}
1245
+ stderr := & bytes.Buffer {}
1246
+ cmd := exec .Command (filepath .Join (testDataPath , "sigprint.sh" ))
1247
+ cmd .Stdout = stdout
1248
+ cmd .Stderr = stderr
1249
+ stdin , err := cmd .StdinPipe ()
1250
+ assert .NoError (t , err )
1251
+
1252
+ m , err := NewMachine (
1253
+ ctx ,
1254
+ cfg ,
1255
+ WithClient (client ),
1256
+ WithProcessRunner (cmd ),
1257
+ WithLogger (fctesting .NewLogEntry (t )),
1258
+ )
1259
+ if err != nil {
1260
+ t .Fatalf ("failed to create new machine: %v" , err )
1261
+ }
1262
+
1263
+ if err := m .startVMM (ctx ); err != nil {
1264
+ t .Fatalf ("error startVMM: %v" , err )
1265
+ }
1266
+ defer m .StopVMM ()
1267
+
1268
+ sigChan := make (chan os.Signal )
1269
+ signal .Notify (sigChan , ignoredSignals ... )
1270
+ defer func () {
1271
+ signal .Stop (sigChan )
1272
+ close (sigChan )
1273
+ }()
1274
+
1275
+ go func () {
1276
+ for sig := range sigChan {
1277
+ t .Logf ("received signal %v, ignoring" , sig )
1278
+ }
1279
+ }()
1280
+
1281
+ go func () {
1282
+ for _ , sig := range append (forwardedSignals , ignoredSignals ... ) {
1283
+ t .Logf ("sending signal %v to self" , sig )
1284
+ syscall .Kill (syscall .Getpid (), sig .(syscall.Signal ))
1285
+ }
1286
+
1287
+ // give the child process time to receive signals and flush pipes
1288
+ time .Sleep (1 * time .Second )
1289
+
1290
+ // terminate the signal printing process
1291
+ stdin .Write ([]byte ("q" ))
1292
+ }()
1293
+
1294
+ err = m .Wait (ctx )
1295
+ require .NoError (t , err , "wait returned an error" )
1296
+
1297
+ receivedSignals := []os.Signal {}
1298
+ for _ , sigStr := range strings .Split (strings .TrimSpace (stdout .String ()), "\n " ) {
1299
+ i , err := strconv .Atoi (sigStr )
1300
+ require .NoError (t , err , "expected numeric output" )
1301
+ receivedSignals = append (receivedSignals , syscall .Signal (i ))
1302
+ }
1303
+
1304
+ assert .ElementsMatch (t , forwardedSignals , receivedSignals )
1305
+ }
0 commit comments