Skip to content

Commit 356404e

Browse files
committed
feat: switch ws1 to oauth2
1 parent df9fb60 commit 356404e

File tree

3 files changed

+120
-21
lines changed

3 files changed

+120
-21
lines changed

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
all: run
2+
13
clean:
24
rm slacker || true
35

@@ -6,5 +8,4 @@ build:
68
chmod +x slacker
79

810
run:
9-
chmod +x slacker
10-
./slacker -dry -config=test.yml
11+
go run ./cmd/ -dry -config=test.yml -noreport -log=trace

config/config.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ type Config struct {
3535

3636
WS1 struct {
3737
Endpoint string `yaml:"api_url" env:"WS1_API_URL"`
38-
APIKey string `yaml:"api_key" env:"WS1_API_KEY"`
39-
User string `yaml:"user" env:"WS1_USER"`
40-
Password string `yaml:"password" env:"WS1_PASSWORD"`
38+
// from https://docs.vmware.com/en/VMware-Workspace-ONE-UEM/services/UEM_ConsoleBasics/GUID-BF20C949-5065-4DCF-889D-1E0151016B5A.html
39+
// e.g. 'emea'
40+
AuthLocation string `yaml:"auth_location" env:"WS1_AUTH_LOCATION"`
41+
ClientID string `yaml:"client_id" env:"WS1_CLIENT_ID"`
42+
ClientSecret string `yaml:"client_secret" env:"WS1_CLIENT_SECRET"`
4143

4244
SkipFilters []struct {
4345
Policy string `yaml:"policy"`
@@ -104,5 +106,13 @@ func (c *Config) Validate() error {
104106
return errors.New("missing message")
105107
}
106108

109+
if c.WS1.ClientSecret == "" || c.WS1.ClientID == "" {
110+
return errors.New("missing WS1 client_id or client_secret")
111+
}
112+
113+
if c.WS1.AuthLocation == "" {
114+
return errors.New("missing WS1 auth_location")
115+
}
116+
107117
return nil
108118
}

pkg/ws1/extractor.go

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ package ws1
33
import (
44
"bytes"
55
"context"
6-
"encoding/base64"
76
"encoding/json"
7+
"fmt"
88
"github.com/hazcod/crowdstrike-spotlight-slacker/config"
99
"github.com/pkg/errors"
1010
"github.com/sirupsen/logrus"
11-
"io/ioutil"
11+
"io"
1212
"net/http"
13+
"net/http/httputil"
14+
"net/url"
1315
"strconv"
1416
"strings"
1517
"time"
@@ -31,46 +33,131 @@ type UserDeviceFinding struct {
3133
ComplianceName string
3234
}
3335

34-
func basicAuth(username, password string) string {
35-
auth := username + ":" + password
36-
return base64.StdEncoding.EncodeToString([]byte(auth))
36+
type LoggedRoundTripper struct {
37+
Proxied http.RoundTripper
38+
Logger *logrus.Logger
3739
}
3840

39-
func doAuthRequest(user, pass, apiKey, url, method string, payload interface{}) (respBytes []byte, err error) {
41+
func (t LoggedRoundTripper) RoundTrip(req *http.Request) (res *http.Response, e error) {
42+
resp, err := t.Proxied.RoundTrip(req)
43+
44+
if t.Logger != nil && t.Logger.IsLevelEnabled(logrus.TraceLevel) {
45+
dumped, err := httputil.DumpRequest(req, true)
46+
if err != nil {
47+
t.Logger.WithError(err).Error("could not dump http request")
48+
} else {
49+
t.Logger.Trace(string(dumped))
50+
}
51+
52+
if req.Response == nil {
53+
t.Logger.Trace("No response")
54+
} else {
55+
dumped, err = httputil.DumpResponse(req.Response, true)
56+
if err != nil {
57+
t.Logger.WithError(err).Error("could not dump http response")
58+
} else {
59+
t.Logger.Trace(string(dumped))
60+
}
61+
}
62+
}
63+
64+
return resp, err
65+
}
66+
67+
type authResponse struct {
68+
Token string `json:"access_token"`
69+
Expires int `json:"expires_in"`
70+
Type string `json:"token_type"`
71+
}
72+
73+
func renewAuth(_ context.Context, ws1AuthLocation, clientID, secret string) (token string, expiry time.Time, err error) {
74+
data := url.Values{
75+
"client_id": {clientID},
76+
"client_secret": {secret},
77+
"grant_type": {"client_credentials"},
78+
}
79+
80+
resp, err := http.PostForm(fmt.Sprintf("https://%s.uemauth.vmwservices.com/connect/token", ws1AuthLocation), data)
81+
if err != nil {
82+
return "", time.Time{}, errors.Wrap(err, "could not post to token endpoint")
83+
}
84+
85+
if resp.StatusCode > 399 {
86+
return "", time.Time{}, errors.Errorf("token endpoint returned status code: %d", resp.StatusCode)
87+
}
88+
89+
defer resp.Body.Close()
90+
91+
respB, err := io.ReadAll(resp.Body)
92+
if err != nil {
93+
return "", time.Time{}, errors.Wrap(err, "could not read token response")
94+
}
95+
96+
logrus.Debugf("%s", string(respB))
97+
98+
var response authResponse
99+
if err := json.Unmarshal(respB, &response); err != nil {
100+
return "", time.Time{}, errors.Wrap(err, "could not decode token response")
101+
}
102+
103+
if !strings.EqualFold(response.Type, "bearer") {
104+
return "", time.Time{}, errors.Wrap(err, "not a bearer token")
105+
}
106+
107+
if response.Expires <= 0 {
108+
return "", time.Time{}, errors.New("empty expires returned")
109+
}
110+
111+
if response.Token == "" {
112+
return "", time.Time{}, errors.New("no token returned")
113+
}
114+
115+
timeExpires := time.Now().Add(time.Second * time.Duration(response.Expires))
116+
117+
if timeExpires.Before(time.Now()) {
118+
return "", time.Time{}, errors.New("token retrieved is already expired")
119+
}
120+
121+
return response.Token, timeExpires, nil
122+
}
123+
124+
func doAuthRequest(ctx context.Context, ws1AuthLocation, clientID, secret, url, method string, payload interface{}) (respBytes []byte, err error) {
40125
var reqPayload []byte
41126
if payload != nil {
42127
if reqPayload, err = json.Marshal(&payload); err != nil {
43128
return nil, errors.Wrap(err, "coult not encode request body")
44129
}
45130
}
46131

132+
token, _, err := renewAuth(ctx, ws1AuthLocation, clientID, secret)
133+
if err != nil {
134+
return nil, errors.Wrap(err, "could not renew auth")
135+
}
136+
47137
req, err := http.NewRequest(method, url, bytes.NewReader(reqPayload))
138+
req = req.WithContext(ctx)
48139
if err != nil {
49140
return nil, errors.Wrap(err, "request failed")
50141
}
51142

52143
req.Header.Set("Accept", "application/json")
53-
req.Header.Set("aw-tenant-code", apiKey)
54-
req.Header.Set("Authorization", "Basic "+basicAuth(user, pass))
55-
56-
httpClient := http.Client{
57-
Timeout: time.Second * 10,
58-
}
144+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
59145

146+
httpClient := http.Client{Timeout: time.Second * 10}
60147
resp, err := httpClient.Do(req)
61148
if err != nil {
62149
return nil, errors.Wrap(err, "http request failed")
63150
}
64151

65152
if resp.StatusCode > 399 {
66-
respB, _ := ioutil.ReadAll(resp.Body)
153+
respB, _ := io.ReadAll(resp.Body)
67154
logrus.WithField("response", string(respB)).Warn("invalid response")
68155
return nil, errors.New("invalid response code: " + strconv.Itoa(resp.StatusCode))
69156
}
70157

71158
defer resp.Body.Close()
72159

73-
if respBytes, err = ioutil.ReadAll(resp.Body); err != nil {
160+
if respBytes, err = io.ReadAll(resp.Body); err != nil {
74161
return nil, errors.New("could not read response body")
75162
}
76163

@@ -79,7 +166,8 @@ func doAuthRequest(user, pass, apiKey, url, method string, payload interface{})
79166

80167
func GetMessages(config *config.Config, ctx context.Context) (map[string]WS1Result, []string, error) {
81168
deviceResponseB, err := doAuthRequest(
82-
config.WS1.User, config.WS1.Password, config.WS1.APIKey,
169+
ctx,
170+
config.WS1.AuthLocation, config.WS1.ClientID, config.WS1.ClientSecret,
83171
strings.TrimRight(config.WS1.Endpoint, "/")+"/mdm/devices/search?compliance_status=NonCompliant",
84172
http.MethodGet,
85173
nil,
@@ -89,7 +177,7 @@ func GetMessages(config *config.Config, ctx context.Context) (map[string]WS1Resu
89177
return nil, nil, errors.Wrap(err, "could not fetch WS1 devices")
90178
}
91179

92-
usersWithDevices := []string{}
180+
usersWithDevices := make([]string, 0)
93181

94182
var devicesResponse DevicesResponse
95183
if err := json.Unmarshal(deviceResponseB, &devicesResponse); err != nil {

0 commit comments

Comments
 (0)