@@ -3,13 +3,15 @@ package ws1
33import (
44 "bytes"
55 "context"
6- "encoding/base64"
76 "encoding/json"
7+ "fmt"
88 "github.com/hazcod/crowdstrike-spotlight-slacker/config"
99 "github.com/pkg/errors"
1010 "github.com/sirupsen/logrus"
11- "io/ioutil "
11+ "io"
1212 "net/http"
13+ "net/http/httputil"
14+ "net/url"
1315 "strconv"
1416 "strings"
1517 "time"
@@ -31,46 +33,131 @@ type UserDeviceFinding struct {
3133 ComplianceName string
3234}
3335
34- func basicAuth ( username , password string ) string {
35- auth := username + ":" + password
36- return base64 . StdEncoding . EncodeToString ([] byte ( auth ))
36+ type LoggedRoundTripper struct {
37+ Proxied http. RoundTripper
38+ Logger * logrus. Logger
3739}
3840
39- func doAuthRequest (user , pass , apiKey , url , method string , payload interface {}) (respBytes []byte , err error ) {
41+ func (t LoggedRoundTripper ) RoundTrip (req * http.Request ) (res * http.Response , e error ) {
42+ resp , err := t .Proxied .RoundTrip (req )
43+
44+ if t .Logger != nil && t .Logger .IsLevelEnabled (logrus .TraceLevel ) {
45+ dumped , err := httputil .DumpRequest (req , true )
46+ if err != nil {
47+ t .Logger .WithError (err ).Error ("could not dump http request" )
48+ } else {
49+ t .Logger .Trace (string (dumped ))
50+ }
51+
52+ if req .Response == nil {
53+ t .Logger .Trace ("No response" )
54+ } else {
55+ dumped , err = httputil .DumpResponse (req .Response , true )
56+ if err != nil {
57+ t .Logger .WithError (err ).Error ("could not dump http response" )
58+ } else {
59+ t .Logger .Trace (string (dumped ))
60+ }
61+ }
62+ }
63+
64+ return resp , err
65+ }
66+
67+ type authResponse struct {
68+ Token string `json:"access_token"`
69+ Expires int `json:"expires_in"`
70+ Type string `json:"token_type"`
71+ }
72+
73+ func renewAuth (_ context.Context , ws1AuthLocation , clientID , secret string ) (token string , expiry time.Time , err error ) {
74+ data := url.Values {
75+ "client_id" : {clientID },
76+ "client_secret" : {secret },
77+ "grant_type" : {"client_credentials" },
78+ }
79+
80+ resp , err := http .PostForm (fmt .Sprintf ("https://%s.uemauth.vmwservices.com/connect/token" , ws1AuthLocation ), data )
81+ if err != nil {
82+ return "" , time.Time {}, errors .Wrap (err , "could not post to token endpoint" )
83+ }
84+
85+ if resp .StatusCode > 399 {
86+ return "" , time.Time {}, errors .Errorf ("token endpoint returned status code: %d" , resp .StatusCode )
87+ }
88+
89+ defer resp .Body .Close ()
90+
91+ respB , err := io .ReadAll (resp .Body )
92+ if err != nil {
93+ return "" , time.Time {}, errors .Wrap (err , "could not read token response" )
94+ }
95+
96+ logrus .Debugf ("%s" , string (respB ))
97+
98+ var response authResponse
99+ if err := json .Unmarshal (respB , & response ); err != nil {
100+ return "" , time.Time {}, errors .Wrap (err , "could not decode token response" )
101+ }
102+
103+ if ! strings .EqualFold (response .Type , "bearer" ) {
104+ return "" , time.Time {}, errors .Wrap (err , "not a bearer token" )
105+ }
106+
107+ if response .Expires <= 0 {
108+ return "" , time.Time {}, errors .New ("empty expires returned" )
109+ }
110+
111+ if response .Token == "" {
112+ return "" , time.Time {}, errors .New ("no token returned" )
113+ }
114+
115+ timeExpires := time .Now ().Add (time .Second * time .Duration (response .Expires ))
116+
117+ if timeExpires .Before (time .Now ()) {
118+ return "" , time.Time {}, errors .New ("token retrieved is already expired" )
119+ }
120+
121+ return response .Token , timeExpires , nil
122+ }
123+
124+ func doAuthRequest (ctx context.Context , ws1AuthLocation , clientID , secret , url , method string , payload interface {}) (respBytes []byte , err error ) {
40125 var reqPayload []byte
41126 if payload != nil {
42127 if reqPayload , err = json .Marshal (& payload ); err != nil {
43128 return nil , errors .Wrap (err , "coult not encode request body" )
44129 }
45130 }
46131
132+ token , _ , err := renewAuth (ctx , ws1AuthLocation , clientID , secret )
133+ if err != nil {
134+ return nil , errors .Wrap (err , "could not renew auth" )
135+ }
136+
47137 req , err := http .NewRequest (method , url , bytes .NewReader (reqPayload ))
138+ req = req .WithContext (ctx )
48139 if err != nil {
49140 return nil , errors .Wrap (err , "request failed" )
50141 }
51142
52143 req .Header .Set ("Accept" , "application/json" )
53- req .Header .Set ("aw-tenant-code" , apiKey )
54- req .Header .Set ("Authorization" , "Basic " + basicAuth (user , pass ))
55-
56- httpClient := http.Client {
57- Timeout : time .Second * 10 ,
58- }
144+ req .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , token ))
59145
146+ httpClient := http.Client {Timeout : time .Second * 10 }
60147 resp , err := httpClient .Do (req )
61148 if err != nil {
62149 return nil , errors .Wrap (err , "http request failed" )
63150 }
64151
65152 if resp .StatusCode > 399 {
66- respB , _ := ioutil .ReadAll (resp .Body )
153+ respB , _ := io .ReadAll (resp .Body )
67154 logrus .WithField ("response" , string (respB )).Warn ("invalid response" )
68155 return nil , errors .New ("invalid response code: " + strconv .Itoa (resp .StatusCode ))
69156 }
70157
71158 defer resp .Body .Close ()
72159
73- if respBytes , err = ioutil .ReadAll (resp .Body ); err != nil {
160+ if respBytes , err = io .ReadAll (resp .Body ); err != nil {
74161 return nil , errors .New ("could not read response body" )
75162 }
76163
@@ -79,7 +166,8 @@ func doAuthRequest(user, pass, apiKey, url, method string, payload interface{})
79166
80167func GetMessages (config * config.Config , ctx context.Context ) (map [string ]WS1Result , []string , error ) {
81168 deviceResponseB , err := doAuthRequest (
82- config .WS1 .User , config .WS1 .Password , config .WS1 .APIKey ,
169+ ctx ,
170+ config .WS1 .AuthLocation , config .WS1 .ClientID , config .WS1 .ClientSecret ,
83171 strings .TrimRight (config .WS1 .Endpoint , "/" )+ "/mdm/devices/search?compliance_status=NonCompliant" ,
84172 http .MethodGet ,
85173 nil ,
@@ -89,7 +177,7 @@ func GetMessages(config *config.Config, ctx context.Context) (map[string]WS1Resu
89177 return nil , nil , errors .Wrap (err , "could not fetch WS1 devices" )
90178 }
91179
92- usersWithDevices := []string {}
180+ usersWithDevices := make ( []string , 0 )
93181
94182 var devicesResponse DevicesResponse
95183 if err := json .Unmarshal (deviceResponseB , & devicesResponse ); err != nil {
0 commit comments