Skip to content

Az.StorageSync | Added TenantId of ARC Server and checked with StorageSyncService tenant to prevent unsupported configuration #28355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ public MockServerManagedIdentityProvider(string testName)

public bool EnableMIChecking { get; set; }

public Guid GetServerApplicationId(LocalServerType serverType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true)
public Task<ServerApplicationIdentity> GetServerApplicationIdentityAsync(LocalServerType serverType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true)
{
return Guid.Empty;
return Task.FromResult(new ServerApplicationIdentity(Guid.Empty, Guid.Empty));
}

public LocalServerType GetServerType(IEcsManagement ecsManagement)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using Microsoft.Azure.Commands.StorageSync.Common;
using Microsoft.Azure.Commands.StorageSync.Common.Extensions;
using Microsoft.Azure.Commands.StorageSync.InternalObjects;
using Microsoft.Azure.Commands.StorageSync.Interop.ManagedIdentity;
using Microsoft.Azure.Management.StorageSync.Models;
using Newtonsoft.Json;
using System;
Expand Down Expand Up @@ -372,17 +373,18 @@ private bool TryCreateDirectory(string monitoringDataPath, out DirectoryInfo dir
return false;
}

public override Guid? GetApplicationIdOrNull()
public override ServerApplicationIdentity GetServerApplicationIdentityOrNull()
{
if(TestName == "TestNewRegisteredServerWithIdentityError")
var testTenantGuid = new Guid("0483643a-cb2f-462a-bc27-1a270e5bdc0a");
if (TestName == "TestNewRegisteredServerWithIdentityError")
{
return null;
}
if(TestName == "TestPatchRegisteredServer")
{
return new Guid("3b990c8b-9607-4c2a-8b04-1d41985facca");
return new ServerApplicationIdentity(new Guid("3b990c8b-9607-4c2a-8b04-1d41985facca"), testTenantGuid);
}
return Guid.NewGuid();
return new ServerApplicationIdentity(Guid.NewGuid(), testTenantGuid);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ public MockSyncServerRegistrationClientBase(IEcsManagement ecsManagementInteropC
/// <summary>
/// This function will return the application id of the server if it is available.
/// </summary>
/// <returns>Application Id or null</returns>
public abstract Guid? GetApplicationIdOrNull();
/// <returns>ServerApplicationIdentity or null</returns>
public abstract ServerApplicationIdentity GetServerApplicationIdentityOrNull();

/// <summary>
/// Validate sync server registration.
Expand Down Expand Up @@ -146,6 +146,7 @@ public void Dispose()
/// 4. Get ClusterInfo
/// 5. Populate RegistrationServerResource
/// </summary>
/// <param name="storageSyncServiceTenantId">Storage Sync Service Tenant Id</param>
/// <param name="managementEndpointUri">Management endpoint Uri</param>
/// <param name="subscriptionId">Subscription Id</param>
/// <param name="storageSyncServiceName">Storage Sync Service Name</param>
Expand All @@ -162,6 +163,7 @@ public void Dispose()
/// </exception>
/// <exception cref="ServerRegistrationException"></exception>
public RegisteredServer Register(
string storageSyncServiceTenantId,
Uri managementEndpointUri,
Guid subscriptionId,
string storageSyncServiceName,
Expand All @@ -176,7 +178,19 @@ public RegisteredServer Register(
bool assignIdentity)
{
// Get ApplicationId
Guid? applicationId = assignIdentity ? GetApplicationIdOrNull() : null;
ServerApplicationIdentity serverApplicationIdentity = assignIdentity ? GetServerApplicationIdentityOrNull() : null;
// Discover the server type , Get the application id,
Guid? applicationId = serverApplicationIdentity?.ApplicationId;

if (serverApplicationIdentity != null && serverApplicationIdentity.TenantId != Guid.Empty)
{
// Check that tenants match
if (!string.Equals(storageSyncServiceTenantId, serverApplicationIdentity.TenantId.ToString(), StringComparison.OrdinalIgnoreCase))
{
throw new ServerRegistrationException(
$"Cross-tenant registration is not allowed. The server belongs to tenant '{serverApplicationIdentity.TenantId}' but the Storage Sync Service is in tenant '{storageSyncServiceTenantId}'.");
}
}

#pragma warning disable CA1416 // Validate platform compatibility
//RegistryUtility.WriteValue(StorageSyncConstants.ServerAuthRegistryKeyName,
Expand Down

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/StorageSync/StorageSync/ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
- Additional information about change #1
-->
## Upcoming Release

* Fixed security bug in checking tenant id for MI server registration
## Version 2.5.1
* Fixed security bug in token acquisition for MI server registration

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
using System.Management;
using System.Management.Automation;
using System.Text.RegularExpressions;
using System.Threading.Tasks;

namespace Commands.StorageSync.Interop.Clients
{
Expand Down Expand Up @@ -406,13 +407,13 @@ private bool TryCreateDirectory(string monitoringDataPath, out DirectoryInfo dir
/// This function will get the application id of the server if identity is available.
/// </summary>
/// <returns>Application id or null.</returns>
public override Guid? GetApplicationIdOrNull()
public async override Task<ServerApplicationIdentity> GetServerApplicationIdentityOrNull()
{
LocalServerType localServerType = this.ServerManagedIdentityProvider.GetServerType(this.EcsManagementInteropClient);

if(localServerType != LocalServerType.HybridServer)
if (localServerType != LocalServerType.HybridServer)
{
return this.ServerManagedIdentityProvider.GetServerApplicationId(localServerType, throwIfNotFound: true, validateSystemAssignedManagedIdentity: true);
return await this.ServerManagedIdentityProvider.GetServerApplicationIdentityAsync(localServerType, throwIfNotFound: true, validateSystemAssignedManagedIdentity: true);
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public abstract ServerRegistrationData Setup(
/// This function will return the application id of the server if it is available.
/// </summary>
/// <returns>Application Id or null</returns>
public abstract Guid? GetApplicationIdOrNull();
public abstract Task<ServerApplicationIdentity> GetServerApplicationIdentityOrNull();

/// <summary>
/// Dispose method for cleaning Interop client object.
Expand All @@ -146,6 +146,7 @@ public void Dispose()
/// 3. Calls RegisterOnline callback to make ARM call (from caller context)
/// 4. Persists registered server resource from cloud to local FileSyncSvc service
/// </summary>
/// <param name="storageSyncServiceTenantId">Storage Sync Service TenantId</param>
/// <param name="managementEndpointUri">Management endpoint Uri</param>
/// <param name="subscriptionId">Subscription Id</param>
/// <param name="storageSyncServiceName">Storage Sync Service Name</param>
Expand All @@ -160,6 +161,7 @@ public void Dispose()
/// <param name="assignIdentity">Assign Identity</param>
/// <returns>Registered Server Resource</returns>
public RegisteredServer Register(
string storageSyncServiceTenantId,
Uri managementEndpointUri,
Guid subscriptionId,
string storageSyncServiceName,
Expand All @@ -174,7 +176,18 @@ public RegisteredServer Register(
bool assignIdentity)
{
// Discover the server type , Get the application id,
Guid? applicationId = assignIdentity ? GetApplicationIdOrNull() : null;
ServerApplicationIdentity serverApplicationIdentity = assignIdentity ? GetServerApplicationIdentityOrNull().GetAwaiter().GetResult() : null;
Guid? applicationId = serverApplicationIdentity?.ApplicationId;

if (serverApplicationIdentity != null && serverApplicationIdentity.TenantId != Guid.Empty)
{
// Check that tenants match
if (!string.Equals(storageSyncServiceTenantId, serverApplicationIdentity.TenantId.ToString(), StringComparison.OrdinalIgnoreCase))
{
throw new ServerRegistrationException(
$"Cross-tenant registration is not allowed. The server belongs to tenant '{serverApplicationIdentity.TenantId}' but the Storage Sync Service is in tenant '{storageSyncServiceTenantId}'.");
}
}

// Set the registry key for ServerAuthType
RegistryUtility.WriteValue(StorageSyncConstants.ServerAuthRegistryKeyName,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Commands.StorageSync.Interop.Interfaces;
using Microsoft.Azure.Commands.StorageSync.Interop.Enums;
using System;
using System.Threading.Tasks;

namespace Microsoft.Azure.Commands.StorageSync.Interop.ManagedIdentity
{
Expand All @@ -20,14 +21,11 @@ public interface IServerManagedIdentityProvider
LocalServerType GetServerType(IEcsManagement ecsManagement);

/// <summary>
/// Gets the server's application id by trying to get a token and parsing for the oid
/// We choose to get the applicationId from the token rather than making a Get call on the resource
/// because we don't know the permissions the user has on the resource
/// Gets the server's application identity (application ID and tenant ID) asynchronously by trying to get a token from the Arc/Azure IMDS endpoint and parsing for the oid and tenant ID.
/// </summary>
/// <param name="serverType">ServerType: Hybrid or Azure</param>
/// <param name="throwIfNotFound">Whether to throw an exception if an Application ID is not available</param>
/// <param name="validateSystemAssignedManagedIdentity">Whether to validate that the Application Id belongs to a System-Assigned Managed Identity</param>
/// <returns>Server's applicationId if it's available, Guid.Empty otherwise</returns>
Guid GetServerApplicationId(LocalServerType serverType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true);
Task<ServerApplicationIdentity> GetServerApplicationIdentityAsync(LocalServerType serverType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public interface ISyncServerRegistration : IDisposable
/// 2. Sets up ServerRegistrationData
/// 3. Calls RegisterOnline callback to make ARM call (from caller context)
/// 4. Persists registered server resource from cloud to local FileSyncSvc service
/// <param name="storageSyncServiceTenantId">Storage Sync Service TenantId</param>
/// <param name="managementEndpointUri">Management endpoint Uri</param>
/// <param name="subscriptionId">Subscription Id</param>
/// <param name="storageSyncServiceName">Storage Sync Service Name</param>
Expand All @@ -47,6 +48,7 @@ public interface ISyncServerRegistration : IDisposable
/// <returns>Registered Server Resource</returns>
/// </summary>
RegisteredServer Register(
string storageSyncServiceTenantId,
Uri managementEndpointUri,
Guid subscriptionId,
string storageSyncServiceName,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;

namespace Microsoft.Azure.Commands.StorageSync.Interop.ManagedIdentity
{
/// <summary>
/// ServerApplicationIdentity represents the server's application identity with application ID and tenant ID.
/// </summary>
public class ServerApplicationIdentity
{
public Guid ApplicationId { get; set; }
public Guid TenantId { get; set; }

public ServerApplicationIdentity(Guid applicationId, Guid tenantId)
{
ApplicationId = applicationId;
TenantId = tenantId;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ namespace Microsoft.Azure.Commands.StorageSync.Interop.ManagedIdentity
/// </summary>
public class ServerManagedIdentityProvider : IServerManagedIdentityProvider
{
private const string ResourceStorageUri = "https://storage.azure.com/";

public bool EnableMIChecking { get; set; }

public Action<string, EventLevel> TraceLog { get; private set; }
Expand Down Expand Up @@ -47,19 +49,20 @@ public LocalServerType GetServerType(IEcsManagement ecsManagement)
/// </summary>
/// <param name="localServerType">ServerType: Hybrid or Azure</param>
/// <param name="throwIfNotFound">Whether to throw an exception if an Application ID is not available</param>
/// <param name="validateSAMI">Whether to validate that the Application Id belongs to a System-Assigned Managed Identity</param>
/// <param name="validateSystemAssignedManagedIdentity">Whether to validate that the Application Id belongs to a System-Assigned Managed Identity</param>
/// <returns>Server's applicationId if it's available, Guid.Empty otherwise</returns>
public Guid GetServerApplicationId(LocalServerType localServerType, bool throwIfNotFound = true, bool validateSAMI = true)
public async Task<ServerApplicationIdentity> GetServerApplicationIdentityAsync(LocalServerType localServerType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true)
{
var applicationId = Guid.Empty;
Guid applicationId = Guid.Empty;
Guid tenantId = Guid.Empty;

if (EnableMIChecking)
{
try
{
if (localServerType == LocalServerType.HybridServer)
{
return applicationId;
return new ServerApplicationIdentity(applicationId, tenantId);
}

// We need to use the https://storage.azure.com resource, as this provides us the x-ms-rid header to use for validation.
Expand All @@ -68,14 +71,12 @@ public Guid GetServerApplicationId(LocalServerType localServerType, bool throwIf
// When we cache token in ServerManagedIdentityTokenProvider, it will use ProtectedMemory to encrypt/decrypt the token,
// and this GetServerApplicationId can be triggered from server registration using PowerShell Core which causes that issue.
// So this is another reason we need to get the token from IMDS endpoint directly via ServerManagedIdentityUtils, not ServerManagedIdentityTokenProvider.
ServerManagedIdentityTokenResponse tokenResponse;

tokenResponse = ServerManagedIdentityUtils.GetManagedIdentityTokenResponseAsync(resource: "https://storage.azure.com/").GetAwaiter().GetResult();

ServerManagedIdentityTokenResponse tokenResponse = await ServerManagedIdentityUtils.GetManagedIdentityTokenResponseAsync(resource: ResourceStorageUri);

var token = tokenResponse.AccessToken;

applicationId = ServerManagedIdentityTokenHelper.GetTokenOid(token);
tenantId = ServerManagedIdentityTokenHelper.GetTokenTenantId(token);
}
catch (Exception)
{
Expand All @@ -90,7 +91,7 @@ public Guid GetServerApplicationId(LocalServerType localServerType, bool throwIf
TraceLog($"{nameof(EnableMIChecking)} is off.", EventLevel.Informational);
}

return applicationId;
return new ServerApplicationIdentity(applicationId, tenantId);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public static class ServerManagedIdentityTokenHelper
/// <summary>
/// Gets the oid claim from the token payload
/// </summary>
/// <param name="token"> the access token </param>
/// <returns> true, oid if successfully parsed, else return false, guid.empty </returns>
/// <param name="token">The access token </param>
/// <returns> The oid as a Guid if successfully parsed, otherwise throws an exception </returns>
public static Guid GetTokenOid(string token)
{
// try to deserialize the json string to aadtoken object
Expand All @@ -25,6 +25,22 @@ public static Guid GetTokenOid(string token)
// parse the oid string to guid object
return Guid.Parse(aadToken?.Oid);
}
/// <summary>
/// Gets the tenantId claim from the token payload
/// </summary>
/// <param name="token"> The access token </param>
/// <returns> The tenantId as a Guid if successfully parsed, otherwise throws an exception </returns>
public static Guid GetTokenTenantId(string token)
{
// try to deserialize the json string to aadtoken object
var aadToken = TryGetAadTokenFromAccessTokenString(token);

if(!Guid.TryParse(aadToken.TenantId, out Guid tenantId))
{
throw new ArgumentException("Token TenantId is invalid");
}
return tenantId;
}

/// <summary>
/// Try to get the Managed Identity type based on the given token response.
Expand Down Expand Up @@ -106,12 +122,17 @@ public class AadToken

[JsonProperty(PropertyName = ManagedIdentityClaimNames.ManagedIdentityResourceId)]
public string MIResourceId { get; set; }

[JsonProperty(PropertyName = ManagedIdentityClaimNames.TenantId)]
public string TenantId { get; set; }
}

public static class ManagedIdentityClaimNames
{
public const string Oid = "oid";

public const string ManagedIdentityResourceId = "xms_mirid";

public const string TenantId = "tid";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ private static bool IsSecretFilePathValid(string secretFilePath)

// Expected form: %ProgramData%\AzureConnectedMachineAgent\Tokens\<guid>.key
var programData = Environment.GetEnvironmentVariable("ProgramData");

if (string.IsNullOrEmpty(programData))
{
// If ProgramData is not found, try to manually construct it using SystemDrive
Expand Down Expand Up @@ -410,4 +410,4 @@ private static LocalServerType GetLocalServerTypeFromRegistry()
return LocalServerType.HybridServer;
}
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,7 @@
<data name="AgentMI_ProgramDataNotFoundError" xml:space="preserve">
<value>GetEnvironmentVariable failed to find 'ProgramData'</value>
</data>
<data name="MissingAzureContextTenantId" xml:space="preserve">
<value>The given azure context does not have tenant id.</value>
</data>
</root>
Loading