Skip to content

Commit d833fff

Browse files
authored
add "sshproxyctl forget persist" (#30)
1 parent fc7cc91 commit d833fff

File tree

6 files changed

+178
-38
lines changed

6 files changed

+178
-38
lines changed

README.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ Version 2 brings a lot of changes to sshproxy:
157157
`sshproxyctl forget host -all|-host HOST [-port PORT]`
158158
- `sshproxyctl error_banner` (without any parameter) has been removed and
159159
replaced by `sshproxyctl forget error_banner`
160+
- `sshproxyctl forget persist [-user USER] [-service SERVICE] [-host HOST] [-port PORT]`
161+
has been added
160162

161163
Copying
162164
-------

cmd/sshproxyctl/sshproxyctl.go

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,24 @@ func disableHost(host, port, configFile string) error {
525525
return cli.SetHost(key, utils.Disabled, time.Now())
526526
}
527527

528+
func forgetPersist(user, service, host, port, configFile string) error {
529+
cli := mustInitEtcdClient(configFile)
530+
defer cli.Close()
531+
532+
history, err := cli.GetHistory(user, service, host, port)
533+
if err != nil {
534+
return err
535+
}
536+
537+
for _, kv := range history {
538+
err := cli.DelHistory(kv.User)
539+
if err != nil {
540+
return err
541+
}
542+
}
543+
return nil
544+
}
545+
528546
func setErrorBanner(errorBanner string, expire time.Time, configFile string) error {
529547
cli := mustInitEtcdClient(configFile)
530548
defer cli.Close()
@@ -598,7 +616,7 @@ The commands are:
598616
version show version number and exit
599617
show show states present in etcd
600618
enable enable a host in etcd
601-
forget forget a host/error_banner in etcd
619+
forget forget a host/error_banner/persist in etcd
602620
disable disable a host in etcd
603621
error_banner set the error banner in etcd
604622
@@ -676,17 +694,22 @@ Enable a previously disabled host in etcd.
676694
return fs
677695
}
678696

679-
func newForgetParser(allFlag *bool, hostString *string, portString *string) *flag.FlagSet {
697+
func newForgetParser(allFlag *bool, hostString, portString, userString, serviceString *string) *flag.FlagSet {
680698
fs := flag.NewFlagSet("forget", flag.ExitOnError)
681699
fs.BoolVar(allFlag, "all", false, "forget all hosts present in config")
682700
fs.StringVar(hostString, "host", "", "hostname to forget (can be a nodeset)")
683701
fs.StringVar(portString, "port", "", "port to forget (can be a nodeset)")
702+
fs.StringVar(userString, "user", "", "forget all persistent connections of this user")
703+
fs.StringVar(serviceString, "service", "", "forget all persistent connections of this service")
684704
fs.Usage = func() {
685705
fmt.Fprintf(flag.CommandLine.Output(), `Usage: %s forget COMMAND [OPTIONS]
686706
687707
The commands are:
688-
host -all|-host HOST [-port PORT] forget a host in etcd
689-
error_banner forget the error_banner in etcd
708+
host -all|-host HOST [-port PORT] forget a host in etcd
709+
error_banner forget the error_banner in etcd
710+
persist [-user USER] [-service SERVICE] [-host HOST] [-port PORT] forget a persistent connection in etcd
711+
(needs at least one option)
712+
(only connections matching all the options are forgotten)
690713
691714
The options are:
692715
`, os.Args[0])
@@ -866,13 +889,14 @@ func main() {
866889
var sourceString string
867890
var hostString string
868891
var portString string
892+
var serviceString string
869893

870894
parsers := map[string]*flag.FlagSet{
871895
"help": newHelpParser(),
872896
"version": newVersionParser(),
873897
"show": newShowParser(&csvFlag, &jsonFlag, &allFlag, &userString, &groupsString, &sourceString),
874898
"enable": newEnableParser(&allFlag, &hostString, &portString),
875-
"forget": newForgetParser(&allFlag, &hostString, &portString),
899+
"forget": newForgetParser(&allFlag, &hostString, &portString, &userString, &serviceString),
876900
"disable": newDisableParser(&allFlag, &hostString, &portString),
877901
"error_banner": newErrorBannerParser(&expire),
878902
}
@@ -977,6 +1001,12 @@ func main() {
9771001
}
9781002
case "error_banner":
9791003
delErrorBanner(*configFile)
1004+
case "persist":
1005+
if userString == "" && serviceString == "" && hostString == "" && portString == "" {
1006+
fmt.Fprintf(os.Stderr, "ERROR: missing '-user', '-service', '-host' or '-port'\n\n")
1007+
p.Usage()
1008+
}
1009+
forgetPersist(userString, serviceString, hostString, portString, *configFile)
9801010
}
9811011
case "disable":
9821012
p := parsers[cmd]

doc/sshproxyctl.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ COMMANDS
7070
*forget error_banner*::
7171
Remove the error banner in etcd.
7272

73+
*forget persist [-user USER] [-service SERVICE] [-host HOST] [-port PORT]*::
74+
Forget a persistent connection in etcd. Needs at least one option.
75+
Only connections matching all the options are forgotten.
76+
7377
*error_banner [-expire EXPIRATION] MESSAGE*::
7478
Set the error banner in etcd. 'MESSAGE' can be multiline. The error
7579
banner is displayed to the client when no backend can be reached (more

misc/sshproxyctl-completion.bash

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ _sshproxyctl() {
99
opts="-h -c ${commands}"
1010

1111
case "${prev}" in
12+
# Main commands
1213
disable)
1314
COMPREPLY=( $(compgen -W '-all -host -port' -- "${cur}") )
1415
;;
@@ -19,59 +20,68 @@ _sshproxyctl() {
1920
COMPREPLY=( $(compgen -W '-expire' -- "${cur}") )
2021
;;
2122
forget)
22-
COMPREPLY=( $(compgen -W '-all -host -port host error_banner' -- "${cur}") )
23+
COMPREPLY=( $(compgen -W '-all -host -port -service -user error_banner host persist' -- "${cur}") )
2324
;;
2425
help)
2526
COMPREPLY=( $(compgen -W "${commands}" -- "${cur}") )
2627
;;
2728
show)
28-
COMPREPLY=( $(compgen -W '-all -csv -json -user -groups -source connections hosts users groups error_banner config' -- "${cur}") )
29+
COMPREPLY=( $(compgen -W '-all -csv -groups -json -source -user config connections error_banner groups hosts users' -- "${cur}") )
30+
;;
31+
# Sub-commands
32+
config)
33+
COMPREPLY=( $(compgen -W '-groups -source -user' -- "${cur}") )
2934
;;
3035
connections)
3136
COMPREPLY=( $(compgen -W '-all -csv -json' -- "${cur}") )
3237
;;
38+
groups)
39+
COMPREPLY=( $(compgen -W '-all -csv -json' -- "${cur}") )
40+
;;
3341
host)
3442
COMPREPLY=( $(compgen -W '-all -host -port' -- "${cur}") )
3543
;;
3644
hosts)
3745
COMPREPLY=( $(compgen -W '-csv -json' -- "${cur}") )
3846
;;
39-
users)
40-
COMPREPLY=( $(compgen -W '-all -csv -json' -- "${cur}") )
47+
persist)
48+
COMPREPLY=( $(compgen -W '-host -port -service -user' -- "${cur}") )
4149
;;
42-
groups)
50+
users)
4351
COMPREPLY=( $(compgen -W '-all -csv -json' -- "${cur}") )
4452
;;
45-
config)
46-
COMPREPLY=( $(compgen -W '-user -groups -source' -- "${cur}") )
47-
;;
53+
# Options
4854
-all)
49-
COMPREPLY=( $(compgen -W '-csv -json -port connections users groups' -- "${cur}") )
55+
COMPREPLY=( $(compgen -W '-csv -json -port connections groups host users' -- "${cur}") )
5056
;;
5157
-csv)
52-
COMPREPLY=( $(compgen -W '-all connections hosts users groups' -- "${cur}") )
58+
COMPREPLY=( $(compgen -W '-all connections groups hosts users' -- "${cur}") )
5359
;;
5460
-groups)
55-
COMPREPLY=( $(compgen -W '-user -source config' -- "${cur}") )
61+
COMPREPLY=( $(compgen -W '-source -user config' -- "${cur}") )
5662
;;
5763
-host)
58-
COMPREPLY=( $(compgen -W '-port' -- "${cur}") )
64+
COMPREPLY=( $(compgen -W '-port -service -user host persist' -- "${cur}") )
5965
;;
6066
-json)
61-
COMPREPLY=( $(compgen -W '-all connections hosts users groups' -- "${cur}") )
67+
COMPREPLY=( $(compgen -W '-all connections groups hosts users' -- "${cur}") )
6268
;;
6369
-port)
64-
COMPREPLY=( $(compgen -W '-all -host' -- "${cur}") )
70+
COMPREPLY=( $(compgen -W '-all -host -service -user host persist' -- "${cur}") )
71+
;;
72+
-service)
73+
COMPREPLY=( $(compgen -W '-host -port -user persist' -- "${cur}") )
6574
;;
6675
-source)
67-
COMPREPLY=( $(compgen -W '-user -groups config' -- "${cur}") )
76+
COMPREPLY=( $(compgen -W '-groups -user config' -- "${cur}") )
6877
;;
6978
-user)
70-
COMPREPLY=( $(compgen -W '-groups -source config' -- "${cur}") )
79+
COMPREPLY=( $(compgen -W '-groups -host -port -service -source config persist' -- "${cur}") )
7180
;;
7281
-c)
7382
_filedir
7483
;;
84+
# Default
7585
*)
7686
COMPREPLY=( $(compgen -W "${opts}" -- "${cur}") )
7787
;;

pkg/utils/etcd.go

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@ import (
2020
"fmt"
2121
"os"
2222
"regexp"
23+
"slices"
2324
"sort"
2425
"strconv"
2526
"strings"
2627
"time"
2728

29+
"github.com/cea-hpc/sshproxy/pkg/nodesets"
30+
2831
"github.com/op/go-logging"
2932
"go.etcd.io/etcd/client/v3"
3033
"go.uber.org/zap"
@@ -399,6 +402,18 @@ func (c *Client) DelHost(hostport string) error {
399402
return nil
400403
}
401404

405+
// DelHistory deletes a history key (passed as "user@service") in etcd.
406+
func (c *Client) DelHistory(history string) error {
407+
key := toHistoryKey(history)
408+
ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout)
409+
_, err := c.cli.Delete(ctx, key, clientv3.WithPrefix())
410+
cancel()
411+
if err != nil {
412+
return err
413+
}
414+
return nil
415+
}
416+
402417
// SetHost sets a host (passed as "host:port") state and last checked time (ts)
403418
// in etcd.
404419
func (c *Client) SetHost(hostport string, state State, ts time.Time) error {
@@ -666,7 +681,7 @@ func (c *Client) GetAllHosts() ([]*FlatHost, error) {
666681
}
667682
}
668683

669-
history, err := c.GetAllHistory()
684+
history, err := c.GetHistory("", "", "", "")
670685
if err != nil {
671686
return nil, fmt.Errorf("ERROR: getting history from etcd: %v", err)
672687
}
@@ -748,7 +763,7 @@ func (c *Client) GetAllUsers(allFlag bool) ([]*FlatUser, error) {
748763
}
749764

750765
if allFlag {
751-
history, err := c.GetAllHistory()
766+
history, err := c.GetHistory("", "", "", "")
752767
if err != nil {
753768
return nil, fmt.Errorf("ERROR: getting history from etcd: %v", err)
754769
}
@@ -863,36 +878,56 @@ type FlatHistory struct {
863878
TTL int64
864879
}
865880

866-
// GetAllHistory returns a list of all history keys present in etcd.
867-
func (c *Client) GetAllHistory() ([]*FlatHistory, error) {
881+
// GetHistory returns a list of matching history keys present in etcd.
882+
func (c *Client) GetHistory(user, service, host, port string) ([]*FlatHistory, error) {
868883
ctx, cancel := context.WithTimeout(context.Background(), c.requestTimeout)
869-
resp, err := c.cli.Get(ctx, etcdHistoryPath, clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend))
884+
resp, err := c.cli.Get(ctx, etcdHistoryPath+"/"+user, clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend))
870885
defer cancel()
871886
if err != nil {
872887
return nil, err
873888
}
874889

875-
history := make([]*FlatHistory, len(resp.Kvs))
876-
for i, ev := range resp.Kvs {
890+
_, nodesetDlclose, nodesetExpand := nodesets.InitExpander()
891+
defer nodesetDlclose()
892+
hosts, err := nodesetExpand(host)
893+
if err != nil {
894+
return nil, err
895+
}
896+
ports, err := nodesetExpand(port)
897+
if err != nil {
898+
return nil, err
899+
}
900+
var history []*FlatHistory
901+
for _, ev := range resp.Kvs {
877902
subkey := string(ev.Key)[len(etcdHistoryPath)+1:]
878903
fields := strings.Split(subkey, "/")
879904
if len(fields) != 2 {
880905
return nil, fmt.Errorf("bad key format %s", subkey)
881906
}
882-
883-
v := &FlatHistory{}
884-
v.User = fields[0]
885-
v.Dest = string(ev.Value)
886-
leaseID, err := strconv.Atoi(fields[1])
907+
evHost, evPort, err := SplitHostPort(string(ev.Value))
887908
if err != nil {
888909
return nil, err
889910
}
890-
ttl, err := c.cli.TimeToLive(ctx, clientv3.LeaseID(leaseID))
891-
if err != nil {
892-
return nil, err
911+
912+
if (user == "" && service == "" && host == "" && port == "") ||
913+
((user == "" || strings.Contains("/"+fields[0], "/"+user+"@")) &&
914+
(service == "" || strings.Contains(fields[0]+"/", "@"+service+"/")) &&
915+
(host == "" || slices.Contains(hosts, evHost)) &&
916+
(port == "" || slices.Contains(ports, evPort))) {
917+
v := &FlatHistory{}
918+
v.User = fields[0]
919+
v.Dest = string(ev.Value)
920+
leaseID, err := strconv.Atoi(fields[1])
921+
if err != nil {
922+
return nil, err
923+
}
924+
ttl, err := c.cli.TimeToLive(ctx, clientv3.LeaseID(leaseID))
925+
if err != nil {
926+
return nil, err
927+
}
928+
v.TTL = ttl.TTL
929+
history = append(history, v)
893930
}
894-
v.TTL = ttl.TTL
895-
history[i] = v
896931
}
897932

898933
return history, nil

0 commit comments

Comments
 (0)