Skip to content

Commit 38cc939

Browse files
committed
Opt: Code quality optimization
1 parent a6ef6d4 commit 38cc939

File tree

4 files changed

+90
-136
lines changed

4 files changed

+90
-136
lines changed

pkg/control/control.go

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
package control
22

3-
import "time"
3+
import (
4+
"context"
5+
"github.com/Driver-C/tryssh/pkg/launcher"
6+
"sync"
7+
"time"
8+
)
49

510
const (
611
TypeUsers = "users"
@@ -10,3 +15,54 @@ const (
1015
TypeKeys = "keys"
1116
sshClientTimeoutWhenLogin = 5 * time.Second
1217
)
18+
19+
func ConcurrencyTryToConnect(concurrency int, connectors []launcher.Connector) []launcher.Connector {
20+
var hitConnectors []launcher.Connector
21+
var mutex sync.Mutex
22+
// If the number of connectors is less than the set concurrency, change the concurrency to the number of connectors
23+
if concurrency > len(connectors) {
24+
concurrency = len(connectors)
25+
}
26+
connectorsChan := make(chan launcher.Connector)
27+
ctx, cancelFunc := context.WithCancel(context.Background())
28+
// Producer
29+
go func(ctx context.Context, connectorsChan chan<- launcher.Connector, connectors []launcher.Connector) {
30+
for _, connector := range connectors {
31+
select {
32+
case <-ctx.Done():
33+
break
34+
default:
35+
connectorsChan <- connector
36+
}
37+
}
38+
close(connectorsChan)
39+
}(ctx, connectorsChan, connectors)
40+
// Consumer
41+
var wg sync.WaitGroup
42+
for i := 0; i < concurrency; i++ {
43+
wg.Add(1)
44+
go func(ctx context.Context, cancelFunc context.CancelFunc,
45+
connectorsChan <-chan launcher.Connector, cwg *sync.WaitGroup, mutex *sync.Mutex) {
46+
defer cwg.Done()
47+
for {
48+
select {
49+
case <-ctx.Done():
50+
return
51+
case connector, ok := <-connectorsChan:
52+
if !ok {
53+
return
54+
}
55+
if err := connector.TryToConnect(); err == nil {
56+
mutex.Lock()
57+
hitConnectors = append(hitConnectors, connector)
58+
mutex.Unlock()
59+
cancelFunc()
60+
}
61+
}
62+
}
63+
}(ctx, cancelFunc, connectorsChan, &wg, &mutex)
64+
}
65+
wg.Wait()
66+
cancelFunc()
67+
return hitConnectors
68+
}

pkg/control/scp.go

Lines changed: 12 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
package control
22

33
import (
4-
"context"
54
"github.com/Driver-C/tryssh/pkg/config"
65
"github.com/Driver-C/tryssh/pkg/launcher"
76
"github.com/Driver-C/tryssh/pkg/utils"
87
"strings"
9-
"sync"
108
"time"
119
)
1210

@@ -79,12 +77,18 @@ func (cc *ScpController) tryCopyWithCache(user string, targetServer *config.Serv
7977

8078
func (cc *ScpController) tryCopyWithoutCache(user string) {
8179
combinations := config.GenerateCombination(cc.destIp, user, cc.configuration)
82-
launchers := launcher.NewScpLaunchersByCombinations(combinations, cc.source, cc.destination, cc.recursive, cc.sshTimeout)
83-
hitLaunchers := concurrencyScpTryToConnect(cc.concurrency, launchers)
80+
launchers := launcher.NewScpLaunchersByCombinations(combinations, cc.source, cc.destination,
81+
cc.recursive, cc.sshTimeout)
82+
connectors := make([]launcher.Connector, len(launchers))
83+
for i, l := range launchers {
84+
connectors[i] = l
85+
}
86+
hitLaunchers := ConcurrencyTryToConnect(cc.concurrency, connectors)
8487
if hitLaunchers != nil {
8588
utils.Logger.Infoln("Login succeeded. The cache will be added.\n")
89+
hitLauncher := hitLaunchers[0].(*launcher.ScpLauncher)
8690
// The new server cache information
87-
newServerCache := launcher.GetConfigFromSshConnector(&hitLaunchers[0].SshConnector)
91+
newServerCache := launcher.GetConfigFromSshConnector(&hitLauncher.SshConnector)
8892
// Determine if the login attempt was successful after the old cache login failed.
8993
// If so, delete the old cache information that cannot be logged in after the login attempt is successful
9094
if cc.cacheIsFound {
@@ -100,10 +104,10 @@ func (cc *ScpController) tryCopyWithoutCache(user string) {
100104
utils.Logger.Infoln("Cache added.\n\n")
101105
// If the timeout time is less than sshClientTimeoutWhenLogin during login,
102106
// change to sshClientTimeoutWhenLogin
103-
if hitLaunchers[0].SshTimeout < sshClientTimeoutWhenLogin {
104-
hitLaunchers[0].SshTimeout = sshClientTimeoutWhenLogin
107+
if hitLauncher.SshTimeout < sshClientTimeoutWhenLogin {
108+
hitLauncher.SshTimeout = sshClientTimeoutWhenLogin
105109
}
106-
if !hitLaunchers[0].Launch() {
110+
if !hitLauncher.Launch() {
107111
utils.Logger.Errorf("Login failed.\n")
108112
}
109113
} else {
@@ -122,59 +126,6 @@ func (cc *ScpController) searchAliasExistsOrNot() {
122126
}
123127
}
124128

125-
func concurrencyScpTryToConnect(concurrency int, launchers []*launcher.ScpLauncher) []*launcher.ScpLauncher {
126-
var hitLaunchers []*launcher.ScpLauncher
127-
var mutex sync.Mutex
128-
var hostKeyMutex sync.Mutex
129-
// If the number of launchers is less than the set concurrency, change the concurrency to the number of launchers
130-
if concurrency > len(launchers) {
131-
concurrency = len(launchers)
132-
}
133-
launchersChan := make(chan *launcher.ScpLauncher)
134-
ctx, cancelFunc := context.WithCancel(context.Background())
135-
// Producer
136-
go func(ctx context.Context, launchersChan chan<- *launcher.ScpLauncher, launchers []*launcher.ScpLauncher) {
137-
for _, launcherP := range launchers {
138-
select {
139-
case <-ctx.Done():
140-
break
141-
default:
142-
launchersChan <- launcherP
143-
}
144-
}
145-
close(launchersChan)
146-
}(ctx, launchersChan, launchers)
147-
// Consumer
148-
var wg sync.WaitGroup
149-
for i := 0; i < concurrency; i++ {
150-
wg.Add(1)
151-
go func(ctx context.Context, cancelFunc context.CancelFunc,
152-
launchersChan <-chan *launcher.ScpLauncher, cwg *sync.WaitGroup, mutex *sync.Mutex) {
153-
defer cwg.Done()
154-
for {
155-
select {
156-
case <-ctx.Done():
157-
return
158-
case launcherP, ok := <-launchersChan:
159-
if !ok {
160-
return
161-
}
162-
launcherP.HostKeyMutex = &hostKeyMutex
163-
if err := launcherP.TryToConnect(); err == nil {
164-
mutex.Lock()
165-
hitLaunchers = append(hitLaunchers, launcherP)
166-
mutex.Unlock()
167-
cancelFunc()
168-
}
169-
}
170-
}
171-
}(ctx, cancelFunc, launchersChan, &wg, &mutex)
172-
}
173-
wg.Wait()
174-
cancelFunc()
175-
return hitLaunchers
176-
}
177-
178129
func NewScpController(source string, destination string, configuration *config.MainConfig) *ScpController {
179130
return &ScpController{
180131
source: source,

pkg/control/ssh.go

Lines changed: 10 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
package control
22

33
import (
4-
"context"
54
"github.com/Driver-C/tryssh/pkg/config"
65
"github.com/Driver-C/tryssh/pkg/launcher"
76
"github.com/Driver-C/tryssh/pkg/utils"
8-
"sync"
97
"time"
108
)
119

@@ -53,11 +51,16 @@ func (sc *SshController) tryLoginWithCache(user string, targetServer *config.Ser
5351
func (sc *SshController) tryLoginWithoutCache(user string) {
5452
combinations := config.GenerateCombination(sc.targetIp, user, sc.configuration)
5553
launchers := launcher.NewSshLaunchersByCombinations(combinations, sc.sshTimeout)
56-
hitLaunchers := concurrencySshTryToConnect(sc.concurrency, launchers)
54+
connectors := make([]launcher.Connector, len(launchers))
55+
for i, l := range launchers {
56+
connectors[i] = l
57+
}
58+
hitLaunchers := ConcurrencyTryToConnect(sc.concurrency, connectors)
5759
if hitLaunchers != nil {
5860
utils.Logger.Infoln("Login succeeded. The cache will be added.\n")
61+
hitLauncher := hitLaunchers[0].(*launcher.SshLauncher)
5962
// The new server cache information
60-
newServerCache := launcher.GetConfigFromSshConnector(&hitLaunchers[0].SshConnector)
63+
newServerCache := launcher.GetConfigFromSshConnector(&hitLauncher.SshConnector)
6164
// Determine if the login attempt was successful after the old cache login failed.
6265
// If so, delete the old cache information that cannot be logged in after the login attempt is successful
6366
if sc.cacheIsFound {
@@ -73,10 +76,10 @@ func (sc *SshController) tryLoginWithoutCache(user string) {
7376
utils.Logger.Infoln("Cache added.\n\n")
7477
// If the timeout time is less than sshClientTimeoutWhenLogin during login,
7578
// change to sshClientTimeoutWhenLogin
76-
if hitLaunchers[0].SshTimeout < sshClientTimeoutWhenLogin {
77-
hitLaunchers[0].SshTimeout = sshClientTimeoutWhenLogin
79+
if hitLauncher.SshTimeout < sshClientTimeoutWhenLogin {
80+
hitLauncher.SshTimeout = sshClientTimeoutWhenLogin
7881
}
79-
if !hitLaunchers[0].Launch() {
82+
if !hitLauncher.Launch() {
8083
utils.Logger.Errorf("Login failed.\n")
8184
}
8285
} else {
@@ -95,59 +98,6 @@ func (sc *SshController) searchAliasExistsOrNot() {
9598
}
9699
}
97100

98-
func concurrencySshTryToConnect(concurrency int, launchers []*launcher.SshLauncher) []*launcher.SshLauncher {
99-
var hitLaunchers []*launcher.SshLauncher
100-
var mutex sync.Mutex
101-
var hostKeyMutex sync.Mutex
102-
// If the number of launchers is less than the set concurrency, change the concurrency to the number of launchers
103-
if concurrency > len(launchers) {
104-
concurrency = len(launchers)
105-
}
106-
launchersChan := make(chan *launcher.SshLauncher)
107-
ctx, cancelFunc := context.WithCancel(context.Background())
108-
// Producer
109-
go func(ctx context.Context, launchersChan chan<- *launcher.SshLauncher, launchers []*launcher.SshLauncher) {
110-
for _, launcherP := range launchers {
111-
select {
112-
case <-ctx.Done():
113-
break
114-
default:
115-
launchersChan <- launcherP
116-
}
117-
}
118-
close(launchersChan)
119-
}(ctx, launchersChan, launchers)
120-
// Consumer
121-
var wg sync.WaitGroup
122-
for i := 0; i < concurrency; i++ {
123-
wg.Add(1)
124-
go func(ctx context.Context, cancelFunc context.CancelFunc,
125-
launchersChan <-chan *launcher.SshLauncher, cwg *sync.WaitGroup, mutex *sync.Mutex) {
126-
defer cwg.Done()
127-
for {
128-
select {
129-
case <-ctx.Done():
130-
return
131-
case launcherP, ok := <-launchersChan:
132-
if !ok {
133-
return
134-
}
135-
launcherP.HostKeyMutex = &hostKeyMutex
136-
if err := launcherP.TryToConnect(); err == nil {
137-
mutex.Lock()
138-
hitLaunchers = append(hitLaunchers, launcherP)
139-
mutex.Unlock()
140-
cancelFunc()
141-
}
142-
}
143-
}
144-
}(ctx, cancelFunc, launchersChan, &wg, &mutex)
145-
}
146-
wg.Wait()
147-
cancelFunc()
148-
return hitLaunchers
149-
}
150-
151101
func NewSshController(targetIp string, configuration *config.MainConfig) *SshController {
152102
return &SshController{
153103
targetIp: targetIp,

pkg/launcher/base.go

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ const (
1818
SSHKeyKeyword = "SSH-KEY"
1919
)
2020

21-
var keysMap = sync.Map{}
21+
var (
22+
keysMap = sync.Map{}
23+
hostKeyMutex = new(sync.Mutex)
24+
)
2225

2326
type Connector interface {
2427
Launch() bool
@@ -28,25 +31,19 @@ type Connector interface {
2831
}
2932

3033
type SshConnector struct {
31-
Ip string
32-
Port string
33-
User string
34-
Password string
35-
Key string
36-
SshTimeout time.Duration
37-
HostKeyMutex *sync.Mutex
34+
Ip string
35+
Port string
36+
User string
37+
Password string
38+
Key string
39+
SshTimeout time.Duration
3840
}
3941

4042
func (sc *SshConnector) Launch() bool {
4143
return false
4244
}
4345

4446
func (sc *SshConnector) LoadConfig() (config *ssh.ClientConfig) {
45-
// If no mutex is passed in, initialize one
46-
if sc.HostKeyMutex == nil {
47-
sc.HostKeyMutex = new(sync.Mutex)
48-
}
49-
5047
var authMethods []ssh.AuthMethod
5148
var privateKey []byte
5249
if sc.Key != "" {
@@ -70,7 +67,7 @@ func (sc *SshConnector) LoadConfig() (config *ssh.ClientConfig) {
7067
config = &ssh.ClientConfig{
7168
User: sc.User,
7269
Auth: authMethods,
73-
HostKeyCallback: trustedHostKeyCallback(searchKeyFromAddress(sc.Ip), sc.Ip, sc.HostKeyMutex),
70+
HostKeyCallback: trustedHostKeyCallback(searchKeyFromAddress(sc.Ip), sc.Ip, hostKeyMutex),
7471
Timeout: sc.SshTimeout,
7572
}
7673
return

0 commit comments

Comments
 (0)