Skip to content

Commit 157926a

Browse files
committed
add GetOrCreateService()
1 parent 7396050 commit 157926a

File tree

5 files changed

+162
-29
lines changed

5 files changed

+162
-29
lines changed

proxmox/service.go

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,78 @@
11
package proxmox
22

33
import (
4+
"crypto/sha256"
45
"crypto/tls"
56
"errors"
7+
"fmt"
68
"net/http"
9+
"sync"
710

811
"github.com/sp-yduck/proxmox-go/rest"
912
)
1013

14+
var (
15+
// global Service map against sessionKeys in map[sessionKey]Session
16+
sessionCache sync.Map
17+
18+
// mutex to control access to the GetOrCreate function to avoid duplicate
19+
// session creations on startup
20+
sessionMutex sync.Mutex
21+
)
22+
1123
type Service struct {
1224
restclient *rest.RESTClient
1325
}
1426

27+
type Params struct {
28+
// base endpoint of proxmox rest api
29+
endpoint string
30+
31+
// auth config
32+
authConfig AuthConfig
33+
34+
// rest client config
35+
clientConfig ClientConfig
36+
}
37+
38+
type ClientConfig struct {
39+
InsecureSkipVerify bool
40+
}
41+
1542
type AuthConfig struct {
16-
Username string
17-
Password string
18-
TokenID string
19-
Secret string
43+
// user or token
44+
AuthMethod string
45+
Username string
46+
Password string
47+
TokenID string
48+
Secret string
2049
}
2150

22-
func NewService(url string, authConfig AuthConfig, insecure bool) (*Service, error) {
23-
var loginOption rest.ClientOption
24-
if authConfig.Username != "" && authConfig.Password != "" {
25-
loginOption = rest.WithUserPassword(authConfig.Username, authConfig.Password)
26-
} else if authConfig.TokenID != "" && authConfig.Secret != "" {
27-
loginOption = rest.WithAPIToken(authConfig.TokenID, authConfig.Secret)
28-
} else {
29-
return nil, errors.New("invalid authentication config")
51+
func GetOrCreateService(params Params) (*Service, error) {
52+
sessionMutex.Lock()
53+
defer sessionMutex.Unlock()
54+
55+
key := retrieveSessionKey(params)
56+
if cachedSession, ok := sessionCache.Load(key); ok {
57+
return cachedSession.(*Service), nil
58+
}
59+
60+
s, err := NewService(params)
61+
if err != nil {
62+
return nil, err
63+
}
64+
sessionCache.Store(key, s)
65+
return s, nil
66+
}
67+
68+
func NewService(params Params) (*Service, error) {
69+
loginOption, err := makeLoginOpts(params.authConfig)
70+
if err != nil {
71+
return nil, err
3072
}
3173
clientOptions := []rest.ClientOption{loginOption}
3274

33-
if insecure {
75+
if params.clientConfig.InsecureSkipVerify {
3476
baseClient := &http.Client{
3577
Transport: &http.Transport{
3678
TLSClientConfig: &tls.Config{
@@ -41,7 +83,7 @@ func NewService(url string, authConfig AuthConfig, insecure bool) (*Service, err
4183
clientOptions = append(clientOptions, rest.WithClient(baseClient))
4284
}
4385

44-
restclient, err := rest.NewRESTClient(url, clientOptions...)
86+
restclient, err := rest.NewRESTClient(params.endpoint, clientOptions...)
4587
if err != nil {
4688
return nil, err
4789
}
@@ -96,3 +138,33 @@ func NewServiceWithAPIToken(url, tokenid, secret string, insecure bool) (*Servic
96138
func (s *Service) RESTClient() *rest.RESTClient {
97139
return s.restclient
98140
}
141+
142+
func makeLoginOpts(authConfig AuthConfig) (rest.ClientOption, error) {
143+
if authConfig.AuthMethod == "token" && authConfig.TokenID != "" && authConfig.Secret != "" {
144+
return rest.WithAPIToken(authConfig.TokenID, authConfig.Secret), nil
145+
} else if authConfig.AuthMethod == "user" && authConfig.Username != "" && authConfig.Password != "" {
146+
return rest.WithUserPassword(authConfig.Username, authConfig.Password), nil
147+
}
148+
return nil, errors.New("invalid authentication config")
149+
}
150+
151+
func retrieveSessionKey(params Params) string {
152+
var id string
153+
var secret []byte
154+
h := sha256.New()
155+
switch params.authConfig.AuthMethod {
156+
case "token":
157+
id = params.authConfig.TokenID
158+
h.Write([]byte(params.authConfig.Secret))
159+
secret = h.Sum(nil)
160+
case "user":
161+
id = params.authConfig.Username
162+
h.Write([]byte(params.authConfig.Password))
163+
secret = h.Sum(nil)
164+
default:
165+
id = params.authConfig.Username
166+
h.Write([]byte(params.authConfig.Password))
167+
secret = h.Sum(nil)
168+
}
169+
return fmt.Sprintf("%s#%s#%x", params.endpoint, id, secret)
170+
}

proxmox/service_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package proxmox
2+
3+
import (
4+
"os"
5+
"testing"
6+
)
7+
8+
func TestGetOrCreateService(t *testing.T) {
9+
url := os.Getenv("PROXMOX_URL")
10+
user := os.Getenv("PROXMOX_USERNAME")
11+
password := os.Getenv("PROXMOX_PASSWORD")
12+
tokeid := os.Getenv("PROXMOX_TOKENID")
13+
secret := os.Getenv("PROXMOX_SECRET")
14+
method := os.Getenv("PROXMOX_AUTH_METHOD")
15+
if url == "" {
16+
t.Fatal("url must not be empty")
17+
}
18+
19+
params := Params{
20+
endpoint: url,
21+
authConfig: AuthConfig{
22+
AuthMethod: method,
23+
Username: user,
24+
Password: password,
25+
TokenID: tokeid,
26+
Secret: secret,
27+
},
28+
clientConfig: ClientConfig{
29+
InsecureSkipVerify: true,
30+
},
31+
}
32+
33+
var svc *Service
34+
for i := 0; i < 10; i++ {
35+
s, err := GetOrCreateService(params)
36+
if err != nil {
37+
t.Fatalf("failed to get/create service: %v", err)
38+
}
39+
if i > 0 && s != svc {
40+
t.Fatalf("should not create new service: %v(cached)!=%v(new)", svc, s)
41+
}
42+
svc = s
43+
}
44+
}

proxmox/suite_test.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,26 @@ func (s *TestSuite) SetupSuite() {
1818
password := os.Getenv("PROXMOX_PASSWORD")
1919
tokeid := os.Getenv("PROXMOX_TOKENID")
2020
secret := os.Getenv("PROXMOX_SECRET")
21+
method := os.Getenv("PROXMOX_AUTH_METHOD")
2122
if url == "" {
2223
s.T().Fatal("url must not be empty")
2324
}
2425

25-
authConfig := AuthConfig{
26-
Username: user,
27-
Password: password,
28-
TokenID: tokeid,
29-
Secret: secret,
26+
params := Params{
27+
endpoint: url,
28+
authConfig: AuthConfig{
29+
AuthMethod: method,
30+
Username: user,
31+
Password: password,
32+
TokenID: tokeid,
33+
Secret: secret,
34+
},
35+
clientConfig: ClientConfig{
36+
InsecureSkipVerify: true,
37+
},
3038
}
3139

32-
service, err := NewService(url, authConfig, true)
40+
service, err := NewService(params)
3341
if err != nil {
3442
s.T().Fatalf("failed to create new service: %v", err)
3543
}

proxmox/websocket_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ func (s *TestSuite) TestVNCWebSocketClient() {
1818
s.T().Fatalf("write error: %v", err)
1919
}
2020

21-
ctx, _ := context.WithTimeout(context.TODO(), 10*time.Second)
21+
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
22+
defer cancel()
2223
out, _, err := client.Read(ctx)
2324
if err != nil {
2425
s.T().Fatalf("failed read message: %v", err)
@@ -35,7 +36,8 @@ func (s *TestSuite) TestExec() {
3536
}
3637
defer client.Close()
3738

38-
ctx, _ := context.WithTimeout(context.TODO(), 5*time.Second)
39+
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
40+
defer cancel()
3941
out, code, err := client.Exec(ctx, "whoami | base64 | base64 -d")
4042
if err != nil {
4143
s.T().Fatalf("failed to exec command: %s : %d : %v", out, code, err)
@@ -51,7 +53,8 @@ func (s *TestSuite) TestWriteFile() {
5153
}
5254
defer client.Close()
5355

54-
ctx, _ := context.WithTimeout(context.TODO(), 15*time.Second)
56+
ctx, cancel := context.WithTimeout(context.TODO(), 15*time.Second)
57+
defer cancel()
5558
err = client.WriteFile(ctx, "this is a file content", "~/test-write-file.txt")
5659
if err != nil {
5760
s.T().Fatalf("failed to exec command: %v", err)

rest/client.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@ import (
1010
"strings"
1111
"time"
1212

13-
"github.com/pkg/errors"
14-
1513
"github.com/sp-yduck/proxmox-go/api"
1614
)
1715

16+
const (
17+
defaultUserAgent = "sp-yduck/proxmox-go"
18+
)
19+
1820
type RESTClient struct {
1921
endpoint string
2022
httpClient *http.Client
23+
tokenid string
2124
token string
2225
session *api.Session
2326
credentials *TicketRequest
@@ -90,6 +93,7 @@ func WithUserPassword(username, password string) ClientOption {
9093

9194
func WithAPIToken(tokenid, secret string) ClientOption {
9295
return func(c *RESTClient) {
96+
c.tokenid = tokenid
9397
c.token = fmt.Sprintf("%s=%s", tokenid, secret)
9498
}
9599
}
@@ -165,9 +169,11 @@ func (c *RESTClient) makeAuthHeaders() http.Header {
165169
header.Add("Accept", "application/json")
166170
if c.token != "" {
167171
header.Add("Authorization", fmt.Sprintf("PVEAPIToken=%s", c.token))
172+
header.Add("User-Agent", fmt.Sprintf("%s:%s", defaultUserAgent, c.tokenid))
168173
} else if c.session != nil {
169174
header.Add("Cookie", fmt.Sprintf("PVEAuthCookie=%s", c.session.Ticket))
170175
header.Add("CSRFPreventionToken", c.session.CSRFPreventionToken)
176+
header.Add("User-Agent", fmt.Sprintf("%s:%s", defaultUserAgent, c.session.Username))
171177
}
172178
return header
173179
}
@@ -188,7 +194,7 @@ func checkResponse(res *http.Response) error {
188194

189195
body, err := io.ReadAll(res.Body)
190196
if err != nil {
191-
return errors.Errorf("failed to read body while handling http response of status %d : %v", res.StatusCode, err)
197+
return fmt.Errorf("failed to read body while handling http response of status %d : %v", res.StatusCode, err)
192198
}
193199

194200
if res.StatusCode == http.StatusInternalServerError || res.StatusCode == http.StatusNotImplemented {
@@ -204,10 +210,10 @@ func checkResponse(res *http.Response) error {
204210
return err
205211
}
206212
if body, ok := errorskey["errors"]; ok {
207-
return errors.Errorf("bad request: %s - %s", res.Status, body)
213+
return fmt.Errorf("bad request: %s - %s", res.Status, body)
208214
}
209-
return errors.Errorf("bad request: %s - %s", res.Status, string(body))
215+
return fmt.Errorf("bad request: %s - %s", res.Status, string(body))
210216
}
211217

212-
return errors.Errorf("code: %d", res.StatusCode)
218+
return fmt.Errorf("code: %d", res.StatusCode)
213219
}

0 commit comments

Comments
 (0)