Skip to content

Commit ae99ee3

Browse files
drewhlicopybara-github
authored andcommitted
Add support for forcing socket connection types for plugins.
PiperOrigin-RevId: 877693904
1 parent 8245bc5 commit ae99ee3

File tree

4 files changed

+103
-4
lines changed

4 files changed

+103
-4
lines changed

internal/cfg/cfg.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ timeout_in_seconds = 60
137137
socket_connections_dir = {{.socketConnectionsDir}}
138138
state_dir = {{.baseStateDir}}
139139
local_plugin_dir = {{.localPluginDir}}
140+
connection_type =
140141
141142
[ACS]
142143
endpoint =
@@ -439,6 +440,10 @@ type Plugin struct {
439440
StateDir string `ini:"state_dir,omitempty"`
440441
// LocalPluginDir defines the directory path where local plugins are installed.
441442
LocalPluginDir string `ini:"local_plugin_dir,omitempty"`
443+
// ConnectionType defines the connection type of the plugin. Only "tcp" and
444+
// "uds" are supported; anything else will be ignored and the default
445+
// behavior will be used.
446+
ConnectionType string `ini:"connection_type,omitempty"`
442447
}
443448

444449
// Unstable contains the configurations of Unstable section. No long term

internal/plugin/manager/pluginlauncher.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ func address(ctx context.Context, protocol, id string, policy retry.Policy) (str
252252
return retry.RunWithResponse(ctx, policy, f)
253253
}
254254

255-
// isUDSSupported returns true if UDS is supported on Windows. Instead of going
256-
// by version to figure out if UDS is supported, try listening on test address
257-
// using UDS, if it gets listener successfully consider UDS is supported.
258-
func isUDSSupported() bool {
255+
// defaultIsUDSSupported returns true if UDS is supported on Windows. Instead of
256+
// going by version to figure out if UDS is supported, try listening on test
257+
// address using UDS, if it gets listener successfully consider UDS is supported.
258+
func defaultIsUDSSupported() bool {
259259
if runtime.GOOS == "linux" {
260260
return true
261261
}

internal/plugin/manager/pluginmanager.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ const (
6262
var (
6363
// pluginManager is the instance of plugin manager.
6464
pluginManager *PluginManager
65+
66+
// isUDSSupported checks if UDS is supported on the host. This is stubbed out
67+
// for testing.
68+
isUDSSupported = defaultIsUDSSupported
6569
)
6670

6771
// PluginManager struct represents the plugins that plugin manager manages.
@@ -262,8 +266,21 @@ func InitPluginManager(ctx context.Context, instanceID string) (*PluginManager,
262266
pluginManager.pendingPluginRevisionsMu.Unlock()
263267
wg.Wait()
264268

269+
// TCP is always supported, so if the forced connection type is TCP, we can
270+
// return early.
271+
connType := cfg.Retrieve().Plugin.ConnectionType
272+
if connType == tcpProtocol {
273+
pluginManager.protocol = tcpProtocol
274+
return pluginManager, nil
275+
}
276+
277+
// If the connection type is not TCP, check if UDS is supported and set the
278+
// protocol to UDS if it is. Otherwise, fallback to TCP.
265279
if isUDSSupported() {
266280
pluginManager.protocol = udsProtocol
281+
} else {
282+
galog.Debugf("UDS is not supported, fallback to TCP")
283+
pluginManager.protocol = tcpProtocol
267284
}
268285
return pluginManager, nil
269286
}

internal/plugin/manager/pluginmanager_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,83 @@ func TestInitPluginManager(t *testing.T) {
422422
}
423423
}
424424

425+
func TestInitPluginManagerConfig(t *testing.T) {
426+
ctx := context.WithValue(context.Background(), client.OverrideConnection, &fakeACS{})
427+
stateDir := t.TempDir()
428+
addr := filepath.Join(t.TempDir(), "pluginA_revisionA.sock")
429+
430+
tests := []struct {
431+
name string
432+
isUDSSupported bool
433+
connectionType string
434+
expectedProtocol string
435+
}{
436+
{
437+
name: "uds_connection_type",
438+
isUDSSupported: true,
439+
connectionType: udsProtocol,
440+
expectedProtocol: udsProtocol,
441+
},
442+
{
443+
name: "tcp_connection_type",
444+
isUDSSupported: true,
445+
connectionType: tcpProtocol,
446+
expectedProtocol: tcpProtocol,
447+
},
448+
{
449+
name: "invalid_connection_type",
450+
isUDSSupported: true,
451+
connectionType: "invalid",
452+
expectedProtocol: udsProtocol,
453+
},
454+
{
455+
name: "uds_not_supported",
456+
isUDSSupported: false,
457+
connectionType: udsProtocol,
458+
expectedProtocol: tcpProtocol,
459+
},
460+
{
461+
name: "uds_not_supported_tcp_connection_type",
462+
isUDSSupported: false,
463+
connectionType: tcpProtocol,
464+
expectedProtocol: tcpProtocol,
465+
},
466+
{
467+
name: "uds_not_supported_invalid_connection_type",
468+
isUDSSupported: false,
469+
connectionType: "invalid",
470+
expectedProtocol: tcpProtocol,
471+
},
472+
}
473+
474+
for _, tc := range tests {
475+
t.Run(tc.name, func(t *testing.T) {
476+
config := fmt.Sprintf("[PluginConfig]\nstate_dir = %s\nconnection_type = %s\n[Core]\nacs_client = false\nsocket_connections_dir = %s", stateDir, tc.connectionType, filepath.Dir(addr))
477+
if err := cfg.Load([]byte(config)); err != nil {
478+
t.Fatalf("cfg.Load(nil) failed unexpectedly with error: %v", err)
479+
}
480+
481+
oldIsUDSSupported := isUDSSupported
482+
t.Cleanup(func() { isUDSSupported = oldIsUDSSupported })
483+
isUDSSupported = func() bool { return tc.isUDSSupported }
484+
485+
pm, err := InitPluginManager(ctx, "1234567890")
486+
if err != nil {
487+
t.Fatalf("InitPluginManager(ctx) failed unexpectedly with error: %v", err)
488+
}
489+
if pm.protocol != tc.expectedProtocol {
490+
t.Errorf("InitPluginManager(ctx) = protocol %q, want %q", pm.protocol, tc.expectedProtocol)
491+
}
492+
493+
t.Cleanup(func() {
494+
if err := command.CurrentMonitor().UnregisterHandler(VMEventCmd); err != nil {
495+
t.Fatalf("command.CurrentMonitor().UnregisterHandler(VMEventCmd) failed unexpectedly with error: %v", err)
496+
}
497+
})
498+
})
499+
}
500+
}
501+
425502
func TestConfigurePluginStates(t *testing.T) {
426503
if err := cfg.Load(nil); err != nil {
427504
t.Fatalf("Failed to load config: %v", err)

0 commit comments

Comments
 (0)