Skip to content

Commit 5ccd34b

Browse files
JordanMontgomerygeorgekarrv
authored andcommitted
Improve Microsoft endpoint validation (#38180)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #13698 If some of the following don't apply, delete the relevant line. - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. - [x] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [x] If paths of existing endpoints are modified without backwards compatibility, checked the frontend/CLI for any necessary changes - [x] Added/updated automated tests - [x] Where appropriate, [automated tests simulate multiple hosts and test for host isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing) (updates to one hosts's records do not affect another) - [x] QA'd all new/changed functionality manually
1 parent a940109 commit 5ccd34b

File tree

9 files changed

+498
-27
lines changed

9 files changed

+498
-27
lines changed

changes/13698-windows-mdm

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* Improved SOAP message validation on Windows MDM endpoints

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
github.com/DATA-DOG/go-sqlmock v1.5.0
1111
github.com/Masterminds/semver v1.5.0
1212
github.com/Masterminds/semver/v3 v3.3.1
13+
github.com/MicahParks/jwkset v0.11.0
1314
github.com/RobotsAndPencils/buford v0.14.0
1415
github.com/VividCortex/mysqlerr v0.0.0-20170204212430-6c6b55f8796f
1516
github.com/WatchBeam/clock v0.0.0-20170901150240-b08e6b4da7ea

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7r
4848
github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
4949
github.com/Masterminds/sprig v2.22.0+incompatible h1:z4yfnGrZ7netVz+0EDJ0Wi+5VZCSYp4Z0m2dk6cEM60=
5050
github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuNhlNS5hqE0NB0E6fgfo2Br3o=
51+
github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ=
52+
github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0=
5153
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
5254
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
5355
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=

pkg/mdm/mdmtest/windows.go

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@ package mdmtest
22

33
import (
44
"bytes"
5+
"crypto/rsa"
56
"crypto/tls"
67
"encoding/base64"
78
"encoding/xml"
9+
"errors"
810
"fmt"
911
"io"
1012
"net/http"
1113
"strconv"
14+
"strings"
1215

1316
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
1417
"github.com/fleetdm/fleet/v4/server/fleet"
@@ -38,6 +41,10 @@ type TestWindowsMDMClient struct {
3841
// queuedCommandResponses tracks the commands that will be sent next
3942
// time the device responds to the server.
4043
queuedCommandResponses map[string]fleet.SyncMLCmd
44+
// jwtSigningKey is the key used to sign JWTs
45+
jwtSigningKey *rsa.PrivateKey
46+
// jwtSigningKeyID is the ID to report in the header for the signing key
47+
jwtSigningKeyID string
4148
}
4249

4350
// This is a test-only enrollment type to force erroneous behavior.
@@ -55,6 +62,19 @@ func TestWindowsMDMClientDebug() TestWindowsMDMClientOption {
5562
}
5663
}
5764

65+
func TestWindowsMDMClientNotInOOBE() TestWindowsMDMClientOption {
66+
return func(c *TestWindowsMDMClient) {
67+
c.notInOOBE = true
68+
}
69+
}
70+
71+
func TestWindowsMDMClientWithSigningKey(signingKey *rsa.PrivateKey, signingKeyID string) TestWindowsMDMClientOption {
72+
return func(c *TestWindowsMDMClient) {
73+
c.jwtSigningKey = signingKey
74+
c.jwtSigningKeyID = signingKeyID
75+
}
76+
}
77+
5878
func NewTestMDMClientWindowsProgramatic(serverURL string, orbitNodeKey string, opts ...TestWindowsMDMClientOption) *TestWindowsMDMClient {
5979
return newTestMDMClient(serverURL, fleet.WindowsMDMProgrammaticEnrollmentType, orbitNodeKey, opts...)
6080
}
@@ -364,9 +384,20 @@ YioVozr1IWYySwWVzMf/SUwKZkKJCAJmSVcixE+4kxPkyPGyauIrN3wWC0zb+mjF
364384
</s:Envelope>
365385
`)
366386

367-
if _, err := c.request(microsoft_mdm.MDE2EnrollPath, enrollReq); err != nil {
387+
resp, err := c.request(microsoft_mdm.MDE2EnrollPath, enrollReq)
388+
if err != nil {
389+
return err
390+
}
391+
392+
// Check for SOAP fault
393+
defer resp.Body.Close()
394+
body, err := io.ReadAll(resp.Body)
395+
if err != nil {
368396
return err
369397
}
398+
if strings.Contains(string(body), "s:fault") {
399+
return fmt.Errorf("enroll request returned SOAP fault: %s", string(body))
400+
}
370401

371402
return nil
372403
}
@@ -406,10 +437,21 @@ func (c *TestWindowsMDMClient) Discovery() error {
406437

407438
// TODO: parse the response and store the policy and enroll endpoints instead
408439
// of hardcoding them to truly test that the server is behaving as expected.
409-
if _, err := c.request(microsoft_mdm.MDE2DiscoveryPath, discoveryReq); err != nil {
440+
resp, err := c.request(microsoft_mdm.MDE2DiscoveryPath, discoveryReq)
441+
if err != nil {
410442
return err
411443
}
412444

445+
// Check for SOAP fault
446+
defer resp.Body.Close()
447+
body, err := io.ReadAll(resp.Body)
448+
if err != nil {
449+
return err
450+
}
451+
if strings.Contains(string(body), "s:fault") {
452+
return fmt.Errorf("discovery request returned SOAP fault: %s", string(body))
453+
}
454+
413455
return nil
414456
}
415457

@@ -457,9 +499,19 @@ func (c *TestWindowsMDMClient) Policy() error {
457499

458500
// TODO: store the policy requirements to generate a certificate and generate
459501
// one on the fly using them instead of using hardcoded values.
460-
if _, err := c.request(microsoft_mdm.MDE2PolicyPath, policyReq); err != nil {
502+
resp, err := c.request(microsoft_mdm.MDE2PolicyPath, policyReq)
503+
if err != nil {
504+
return err
505+
}
506+
// Check for SOAP fault
507+
defer resp.Body.Close()
508+
body, err := io.ReadAll(resp.Body)
509+
if err != nil {
461510
return err
462511
}
512+
if strings.Contains(string(body), "s:fault") {
513+
return fmt.Errorf("policy request returned SOAP fault: %s", string(body))
514+
}
463515

464516
return nil
465517
}
@@ -493,9 +545,13 @@ func (c *TestWindowsMDMClient) getToken() (binarySecToken string, tokenValueType
493545
"unique_name": "foo_bar",
494546
"scp": "mdm_delegation",
495547
}
548+
if c.jwtSigningKey == nil || c.jwtSigningKeyID == "" {
549+
return "", "", errors.New("jwt signing key is not set")
550+
}
496551

497-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
498-
tokenString, err := token.SignedString([]byte("foo"))
552+
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
553+
token.Header["kid"] = c.jwtSigningKeyID
554+
tokenString, err := token.SignedString(c.jwtSigningKey)
499555
if err != nil {
500556
return "", "", err
501557
}

server/mdm/microsoft/wstep.go

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@ import (
1313
"errors"
1414
"fmt"
1515
"math/big"
16+
"os"
1617
"strconv"
1718
"strings"
1819
"time"
1920

21+
"github.com/MicahParks/jwkset"
2022
"github.com/fleetdm/fleet/v4/server"
23+
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
2124
"github.com/fleetdm/fleet/v4/server/mdm/microsoft/syncml"
2225
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/cryptoutil"
2326
"github.com/golang-jwt/jwt/v4"
@@ -215,7 +218,7 @@ func (m *manager) GetSTSAuthTokenUPNClaim(tokenStr string) (string, error) {
215218
}
216219

217220
// Since we used the private key to sign the tokens, we use the public counterpart to verify the signature
218-
token, err := jwt.ParseWithClaims(tokenStr, &STSClaims{}, func(token *jwt.Token) (interface{}, error) {
221+
token, err := jwt.ParseWithClaims(tokenStr, &STSClaims{}, func(token *jwt.Token) (any, error) {
219222
return m.identityCert.PublicKey, nil
220223
})
221224
if err != nil {
@@ -235,27 +238,71 @@ func (m *manager) GetSTSAuthTokenUPNClaim(tokenStr string) (string, error) {
235238

236239
// GetAzureAuthTokenClaims validates the given Azure AD token and returns
237240
// UPN, TenantID, UniqueName, DeviceID
238-
func GetAzureAuthTokenClaims(tokenStr string) (AzureData, error) {
241+
func GetAzureAuthTokenClaims(ctx context.Context, tokenStr string) (AzureData, error) {
239242
if len(tokenStr) == 0 {
240-
return AzureData{}, errors.New("invalid STS token")
243+
return AzureData{}, ctxerr.New(ctx, "invalid STS token")
241244
}
242245

243246
// Decode base64 token
244247
tokenBytes, err := base64.StdEncoding.DecodeString(tokenStr)
245248
if err != nil {
246-
return AzureData{}, errors.New("invalid Azure JWT token")
249+
return AzureData{}, ctxerr.Wrap(ctx, err, "invalid Azure JWT token")
247250
}
248251

249252
// Validate token format (header.payload.signature)
250253
parts := bytes.Split(tokenBytes, []byte("."))
251254
if len(parts) != 3 {
252-
return AzureData{}, errors.New("invalid Azure JWT format")
255+
return AzureData{}, ctxerr.New(ctx, "invalid Azure JWT format")
253256
}
254257

255258
// Parse JWT token
256-
token, _, err := new(jwt.Parser).ParseUnverified(string(tokenBytes), jwt.MapClaims{})
259+
jwksURI := "https://login.microsoftonline.com/common/discovery/v2.0/keys"
260+
var token *jwt.Token
261+
FLEET_DEV_AZURE_JWT_JWKS_URI := os.Getenv("FLEET_DEV_AZURE_JWT_JWKS_URI")
262+
if FLEET_DEV_AZURE_JWT_JWKS_URI != "" {
263+
jwksURI = FLEET_DEV_AZURE_JWT_JWKS_URI
264+
}
265+
266+
keys, err := jwkset.NewDefaultHTTPClient([]string{jwksURI})
267+
if err != nil {
268+
return AzureData{}, ctxerr.Wrap(ctx, err, "failed to retrieve Azure JWT signing keys")
269+
}
270+
token, err = jwt.Parse(string(tokenBytes), func(token *jwt.Token) (any, error) {
271+
tokenAlg, ok := token.Header["alg"]
272+
if !ok {
273+
return nil, errors.New("Azure JWT missing alg header")
274+
}
275+
tokenAlgStr, ok := tokenAlg.(string)
276+
if !ok {
277+
return nil, errors.New("invalid alg header in Azure JWT")
278+
}
279+
280+
kid, ok := token.Header["kid"]
281+
if !ok {
282+
return nil, errors.New("Azure JWT missing kid header")
283+
}
284+
kidStr, ok := kid.(string)
285+
if !ok {
286+
return nil, errors.New("invalid kid header in Azure JWT")
287+
}
288+
289+
key, err := keys.KeyRead(ctx, kidStr)
290+
if err != nil {
291+
if errors.Is(err, jwkset.ErrKeyNotFound) {
292+
return nil, fmt.Errorf("Azure JWT signed by unknown key: %w", err)
293+
}
294+
return nil, fmt.Errorf("failed to retrieve Azure JWT signing key: %w", err)
295+
}
296+
297+
// Alg is optional in the JWK but if present must match the token
298+
keyAlg := key.Marshal().ALG.String()
299+
if keyAlg != "" && keyAlg != tokenAlgStr {
300+
return nil, fmt.Errorf("Azure JWT signing key algorithm mismatch: expected %s from key, got %s", keyAlg, tokenAlgStr)
301+
}
302+
return key.Key(), nil
303+
})
257304
if err != nil {
258-
return AzureData{}, errors.New("parse error Azure JWT content")
305+
return AzureData{}, ctxerr.Wrap(ctx, err, "parse error Azure JWT content")
259306
}
260307

261308
// Parse JWT token
@@ -264,25 +311,25 @@ func GetAzureAuthTokenClaims(tokenStr string) (AzureData, error) {
264311
// Get UPN claim
265312
upnClaim, ok := claims["upn"].(string)
266313
if !ok || len(upnClaim) == 0 {
267-
return AzureData{}, errors.New("invalid UPN claim")
314+
return AzureData{}, ctxerr.New(ctx, "invalid UPN claim")
268315
}
269316

270317
// Get TenantID claim
271318
tenantIDClaim, ok := claims["tid"].(string)
272319
if !ok || len(tenantIDClaim) == 0 {
273-
return AzureData{}, errors.New("invalid TenantID claim")
320+
return AzureData{}, ctxerr.New(ctx, "invalid TenantID claim")
274321
}
275322

276323
// Get UniqueName claim
277324
uniqueNameClaim, ok := claims["unique_name"].(string)
278325
if !ok {
279-
return AzureData{}, errors.New("invalid UniqueName claim")
326+
return AzureData{}, ctxerr.New(ctx, "invalid UniqueName claim")
280327
}
281328

282329
// Get SCP claim
283330
azureSCPClaim, ok := claims["scp"].(string)
284331
if !ok || azureSCPClaim != "mdm_delegation" {
285-
return AzureData{}, errors.New("invalid SCP claim")
332+
return AzureData{}, ctxerr.New(ctx, "invalid SCP claim")
286333
}
287334

288335
return AzureData{

server/service/integration_logger_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ func (s *integrationLoggerTestSuite) TestWindowsMDMEnrollEmptyBinarySecurityToke
371371
host := createOrbitEnrolledHost(t, "windows", "", s.ds)
372372
mdmDevice := mdmtest.NewTestMDMClientWindowsEmptyBinarySecurityToken(s.server.URL, *host.OrbitNodeKey)
373373
err = mdmDevice.Enroll()
374-
require.NoError(t, err)
374+
require.Error(t, err)
375375

376376
t.Log(s.buf.String())
377377

@@ -394,12 +394,11 @@ func (s *integrationLoggerTestSuite) TestWindowsMDMEnrollEmptyBinarySecurityToke
394394
require.Equal(t, "info", m["level"])
395395
require.Equal(t, "binarySecurityToken is empty", m["soap_fault"])
396396
case microsoft_mdm.MDE2EnrollPath:
397-
require.Equal(t, "info", m["level"])
398-
require.Equal(t, "binarySecurityToken is empty", m["soap_fault"])
399-
foundEnroll = true
397+
foundEnroll = false
400398
}
401399
}
402400
require.True(t, foundDiscovery)
403401
require.True(t, foundPolicy)
404-
require.True(t, foundEnroll)
402+
// Will not enroll due to soap fault on prior request
403+
require.False(t, foundEnroll)
405404
}

server/service/integration_mdm_lifecycle_test.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,16 @@ func (s *integrationMDMTestSuite) TestTurnOnLifecycleEventsWindows() {
326326
// the ack of the message should be the only returned command
327327
require.Len(t, cmds, 1)
328328

329+
// Simulate the host having fleetd installed and reporting back in as un-enrolled
330+
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
331+
_, err := q.ExecContext(context.Background(), `
332+
UPDATE host_mdm
333+
SET enrolled = 0, server_url = ''
334+
WHERE host_id = ?
335+
`, host.ID)
336+
return err
337+
})
338+
329339
// re-enroll
330340
require.NoError(t, device.Enroll())
331341
},
@@ -358,13 +368,28 @@ func (s *integrationMDMTestSuite) TestTurnOnLifecycleEventsWindows() {
358368
&orbitScriptResp,
359369
)
360370

371+
// Simulate the host having fleetd installed after being wiped and reporting back in as un-enrolled
372+
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
373+
_, err := q.ExecContext(context.Background(), `
374+
UPDATE host_mdm
375+
SET enrolled = 0, server_url = ''
376+
WHERE host_id = ?
377+
`, host.ID)
378+
return err
379+
})
380+
361381
require.NoError(t, device.Enroll())
362382
},
363383
},
364384
{
365385
"host turns on MDM features out of the blue",
366386
func(t *testing.T, host *fleet.Host, device *mdmtest.TestWindowsMDMClient) {
367-
require.NoError(t, device.Enroll())
387+
if strings.Contains(t.Name(), "automatic") {
388+
require.NoError(t, device.Enroll())
389+
} else {
390+
// A programatically-enrolled host that randomly turns on MDM after already enabled will get a SOAP fault
391+
require.Error(t, device.Enroll())
392+
}
368393
},
369394
},
370395
{
@@ -437,7 +462,7 @@ func (s *integrationMDMTestSuite) TestTurnOnLifecycleEventsWindows() {
437462
host := createOrbitEnrolledHost(t, "windows", "windows_automatic", s.ds)
438463

439464
azureMail := "foo.bar.baz@example.com"
440-
device := mdmtest.NewTestMDMClientWindowsAutomatic(s.server.URL, azureMail)
465+
device := mdmtest.NewTestMDMClientWindowsAutomatic(s.server.URL, azureMail, mdmtest.TestWindowsMDMClientWithSigningKey(s.jwtSigningKey, defaultFakeJWTKeyID))
441466
device.HardwareID = host.UUID
442467
device.DeviceID = host.UUID
443468
require.NoError(t, device.Enroll())

0 commit comments

Comments
 (0)