Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>com.uid2</groupId>
<artifactId>uid2-shared</artifactId>
<version>8.0.32</version>
<version>8.0.34-alpha-191-SNAPSHOT</version>
<name>${project.groupId}:${project.artifactId}</name>
<description>Library for all the shared uid2 operations</description>
<url>https://github.com/IABTechLab/uid2docs</url>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,38 @@ public class AzureCCCoreAttestationService implements ICoreAttestationService {

private final IPolicyValidator policyValidator;

public AzureCCCoreAttestationService(String maaServerBaseUrl, String attestationUrl) {
this(new MaaTokenSignatureValidator(maaServerBaseUrl), new PolicyValidator(attestationUrl));
private final String azureCcProtocol;

public AzureCCCoreAttestationService(String maaServerBaseUrl, String attestationUrl, String azureCcProtocol) {
this(new MaaTokenSignatureValidator(maaServerBaseUrl), new PolicyValidator(attestationUrl), azureCcProtocol);
}

// used in UT
protected AzureCCCoreAttestationService(IMaaTokenSignatureValidator tokenSignatureValidator, IPolicyValidator policyValidator) {
protected AzureCCCoreAttestationService(IMaaTokenSignatureValidator tokenSignatureValidator, IPolicyValidator policyValidator, String azureCcProtocol) {
this.tokenSignatureValidator = tokenSignatureValidator;
this.policyValidator = policyValidator;
this.azureCcProtocol = azureCcProtocol;
}

@Override
public void attest(byte[] attestationRequest, byte[] publicKey, Handler<AsyncResult<AttestationResult>> handler) {
try {
var tokenString = new String(attestationRequest, StandardCharsets.US_ASCII);

log.debug("Attesting for {} operator...", azureCcProtocol);
log.debug("Validating signature...");
var tokenPayload = tokenSignatureValidator.validate(tokenString);
var tokenPayload = tokenSignatureValidator.validate(tokenString, azureCcProtocol);

log.debug("Validating policy...");
var encodedPublicKey = Utils.toBase64String(publicKey);

var enclaveId = policyValidator.validate(tokenPayload, encodedPublicKey);

if (allowedEnclaveIds.contains(enclaveId)) {
log.info("Successfully attested azure-cc against registered enclaves, enclave id: " + enclaveId);
log.info("Successfully attested {} against registered enclaves, enclave id: {}", azureCcProtocol, enclaveId);
handler.handle(Future.succeededFuture(new AttestationResult(publicKey, enclaveId)));
} else {
log.warn("Got unsupported azure-cc enclave id: " + enclaveId);
log.warn("Got unsupported {} enclave id: {}", azureCcProtocol, enclaveId);
handler.handle(Future.succeededFuture(new AttestationResult(AttestationFailure.FORBIDDEN_ENCLAVE)));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ public interface IMaaTokenSignatureValidator {
* @return Parsed token payload.
* @throws AttestationException
*/
MaaTokenPayload validate(String tokenString) throws AttestationException;
MaaTokenPayload validate(String tokenString, String protocol) throws AttestationException;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we introduce an enum for the protocols instead of using a String?

}
19 changes: 17 additions & 2 deletions src/main/java/com/uid2/shared/secure/azurecc/MaaTokenPayload.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
package com.uid2.shared.secure.azurecc;

import com.uid2.shared.secure.AttestationClientException;
import com.uid2.shared.secure.AttestationException;
import com.uid2.shared.secure.AttestationFailure;
import lombok.Builder;
import lombok.Value;

@Value
@Builder(toBuilder = true)
public class MaaTokenPayload {
public static final String SEV_SNP_VM_TYPE = "sevsnpvm";
public static final String AZURE_CC_ACI_PROTOCOL = "azure-cc";
public static final String AZURE_CC_AKS_PROTOCOL = "azure-cc-aks";
// the `x-ms-compliance-status` value for ACI CC
public static final String AZURE_COMPLIANT_UVM = "azure-compliant-uvm";
// the `x-ms-compliance-status` value for AKS CC
public static final String AZURE_COMPLIANT_UVM_AKS = "azure-signed-katacc-uvm";

private String azureProtocol;
private String attestationType;
private String complianceStatus;
private boolean vmDebuggable;
Expand All @@ -20,7 +29,13 @@ public boolean isSevSnpVM(){
return SEV_SNP_VM_TYPE.equalsIgnoreCase(attestationType);
}

public boolean isUtilityVMCompliant(){
return AZURE_COMPLIANT_UVM.equalsIgnoreCase(complianceStatus);
public boolean isUtilityVMCompliant() throws AttestationClientException {
if (azureProtocol == AZURE_CC_ACI_PROTOCOL) {
return AZURE_COMPLIANT_UVM.equalsIgnoreCase(complianceStatus);
} else if (azureProtocol == AZURE_CC_AKS_PROTOCOL) {
return AZURE_COMPLIANT_UVM_AKS.equalsIgnoreCase(complianceStatus);
} else {
throw new AttestationClientException(String.format("Azure protocol: %s not supported", azureProtocol), AttestationFailure.INVALID_PROTOCOL);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import static com.uid2.shared.secure.JwtUtils.tryGetField;

public class MaaTokenSignatureValidator implements IMaaTokenSignatureValidator {

// set to true to facilitate local test.
public static final boolean BYPASS_SIGNATURE_CHECK = false;

Expand Down Expand Up @@ -52,7 +51,7 @@ private TokenVerifier buildTokenVerifier(String kid) throws AttestationException
}

@Override
public MaaTokenPayload validate(String tokenString) throws AttestationException {
public MaaTokenPayload validate(String tokenString, String protocol) throws AttestationException {
if (Strings.isNullOrEmpty(tokenString)) {
throw new IllegalArgumentException("tokenString can not be null or empty");
}
Expand All @@ -77,6 +76,7 @@ public MaaTokenPayload validate(String tokenString) throws AttestationException

var tokenPayloadBuilder = MaaTokenPayload.builder();

tokenPayloadBuilder.azureProtocol(protocol);
tokenPayloadBuilder.attestationType(tryGetField(rawPayload, "x-ms-attestation-type", String.class));
tokenPayloadBuilder.complianceStatus(tryGetField(rawPayload, "x-ms-compliance-status", String.class));
tokenPayloadBuilder.vmDebuggable(tryGetField(rawPayload, "x-ms-sevsnpvm-is-debuggable", Boolean.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
Expand All @@ -18,6 +22,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -55,25 +60,27 @@ private static byte[] encodeStringUnicodeAttestationEndpoint(String data) {

@BeforeEach
public void setup() throws AttestationException {
when(alwaysPassTokenValidator.validate(any())).thenReturn(VALID_TOKEN_PAYLOAD);
when(alwaysPassTokenValidator.validate(any(), any())).thenReturn(VALID_TOKEN_PAYLOAD);
when(alwaysPassPolicyValidator.validate(any(), any())).thenReturn(ENCLAVE_ID);
}

@Test
public void testHappyPath() throws AttestationException {
var provider = new AzureCCCoreAttestationService(alwaysPassTokenValidator, alwaysPassPolicyValidator);
@ParameterizedTest
@MethodSource("argumentProvider")
public void testHappyPath(String azureProtocol) throws AttestationException {
var provider = new AzureCCCoreAttestationService(alwaysPassTokenValidator, alwaysPassPolicyValidator, azureProtocol);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
assertTrue(ar.succeeded());
assertTrue(ar.result().isSuccess());
});
}

@Test
public void testSignatureCheckFailed_ClientError() throws AttestationException {
@ParameterizedTest
@MethodSource("argumentProvider")
public void testSignatureCheckFailed_ClientError(String azureProtocol) throws AttestationException {
var errorStr = "token signature validation failed";
when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationClientException(errorStr, AttestationFailure.BAD_PAYLOAD));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator);
when(alwaysFailTokenValidator.validate(any(), any())).thenThrow(new AttestationClientException(errorStr, AttestationFailure.BAD_PAYLOAD));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
assertTrue(ar.succeeded());
Expand All @@ -82,22 +89,24 @@ public void testSignatureCheckFailed_ClientError() throws AttestationException {
});
}

@Test
public void testSignatureCheckFailed_ServerError() throws AttestationException {
when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationException("unknown server error"));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator);
@ParameterizedTest
@MethodSource("argumentProvider")
public void testSignatureCheckFailed_ServerError(String azureProtocol) throws AttestationException {
when(alwaysFailTokenValidator.validate(any(), any())).thenThrow(new AttestationException("unknown server error"));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
assertFalse(ar.succeeded());
assertTrue(ar.cause() instanceof AttestationException);
});
}

@Test
public void testPolicyCheckSuccess_ClientError() throws AttestationException {
@ParameterizedTest
@MethodSource("argumentProvider")
public void testPolicyCheckSuccess_ClientError(String azureProtocol) throws AttestationException {
var errorStr = "policy validation failed";
when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationClientException(errorStr, AttestationFailure.BAD_PAYLOAD));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysFailPolicyValidator);
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysFailPolicyValidator, azureProtocol);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
assertTrue(ar.succeeded());
Expand All @@ -106,20 +115,22 @@ public void testPolicyCheckSuccess_ClientError() throws AttestationException {
});
}

@Test
public void testPolicyCheckFailed_ServerError() throws AttestationException {
@ParameterizedTest
@MethodSource("argumentProvider")
public void testPolicyCheckFailed_ServerError(String azureProtocol) throws AttestationException {
when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationException("unknown server error"));
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysFailPolicyValidator);
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysFailPolicyValidator, azureProtocol);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
assertFalse(ar.succeeded());
assertTrue(ar.cause() instanceof AttestationException);
});
}

@Test
public void testEnclaveNotRegistered() throws AttestationException {
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator);
@ParameterizedTest
@MethodSource("argumentProvider")
public void testEnclaveNotRegistered(String azureProtocol) throws AttestationException {
var provider = new AzureCCCoreAttestationService(alwaysFailTokenValidator, alwaysPassPolicyValidator, azureProtocol);
attest(provider, ar -> {
assertTrue(ar.succeeded());
assertFalse(ar.result().isSuccess());
Expand All @@ -133,4 +144,11 @@ private static void attest(ICoreAttestationService provider, Handler<AsyncResult
PUBLIC_KEY.getBytes(StandardCharsets.UTF_8),
handler);
}

static Stream<Arguments> argumentProvider() {
return Stream.of(
Arguments.of(MaaTokenPayload.AZURE_CC_ACI_PROTOCOL),
Arguments.of(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@
import com.uid2.shared.secure.AttestationException;
import com.uid2.shared.secure.TestClock;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.stream.Stream;

import static com.uid2.shared.secure.TestUtils.loadFromJson;
import static com.uid2.shared.secure.azurecc.MaaTokenUtils.validateAndParseToken;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class MaaTokenSignatureValidatorTest {
@Test
public void testPayload() throws Exception {
@ParameterizedTest
@MethodSource("argumentProvider")
public void testPayload(String payloadPath, String protocol) throws Exception {
// expire at 1695313895
var payloadPath = "/com.uid2.shared/test/secure/azurecc/jwt_payload.json";
var payload = loadFromJson(payloadPath);
var clock = new TestClock();
clock.setCurrentTimeMs(1695313893000L);
Expand All @@ -22,7 +26,7 @@ public void testPayload() throws Exception {
var expectedLocation = "East US";
var expectedPublicKey = "abc";

var tokenPayload = validateAndParseToken(payload, clock);
var tokenPayload = validateAndParseToken(payload, clock, protocol);
assertEquals(true, tokenPayload.isSevSnpVM());
assertEquals(true, tokenPayload.isUtilityVMCompliant());
assertEquals(false, tokenPayload.isVmDebuggable());
Expand All @@ -37,6 +41,13 @@ public void testE2E() throws AttestationException {
var maaToken = "<Placeholder>";
var maaServerUrl = "https://sharedeus.eus.attest.azure.net";
var validator = new MaaTokenSignatureValidator(maaServerUrl);
var token = validator.validate(maaToken);
var token = validator.validate(maaToken, MaaTokenPayload.AZURE_CC_ACI_PROTOCOL);
}

static Stream<Arguments> argumentProvider() {
return Stream.of(
Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aci.json", MaaTokenPayload.AZURE_CC_ACI_PROTOCOL),
Arguments.of("/com.uid2.shared/test/secure/azurecc/jwt_payload_aks.json", MaaTokenPayload.AZURE_CC_AKS_PROTOCOL)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
public class MaaTokenUtils {
public static final String MAA_BASE_URL = "https://sharedeus.eus.attest.azure.net";

public static MaaTokenPayload validateAndParseToken(JsonObject payload, Clock clock) throws Exception{
public static MaaTokenPayload validateAndParseToken(JsonObject payload, Clock clock, String protocol) throws Exception{
var gen = KeyPairGenerator.getInstance(Const.Name.AsymetricEncryptionKeyClass);
gen.initialize(2048, new SecureRandom());
var keyPair = gen.generateKeyPair();
Expand All @@ -30,7 +30,7 @@ public static MaaTokenPayload validateAndParseToken(JsonObject payload, Clock cl
var tokenVerifier = new MaaTokenSignatureValidator(MAA_BASE_URL, keyProvider, clock);

// validate token
return tokenVerifier.validate(token);
return tokenVerifier.validate(token, protocol);
}

private static class MockKeyProvider implements IPublicKeyProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ private MaaTokenPayload generateBasicPayload() {
.vmDebuggable(false)
.runtimeData(generateBasicRuntimeData())
.ccePolicyDigest(CCE_POLICY_DIGEST)
.azureProtocol(MaaTokenPayload.AZURE_CC_ACI_PROTOCOL)
.build();
}

Expand Down Expand Up @@ -125,4 +126,53 @@ public void testValidationFailure_DifferentAttestationUrl() {
assertEquals(AttestationFailure.UNKNOWN_ATTESTATION_URL, ((AttestationClientException)t).getAttestationFailure());

}

@Test
public void testValidationFailure_AzureCcWithOtherUvm() {
var validator = new PolicyValidator(ATTESTATION_URL);
var aksPayload = generateBasicPayload()
.toBuilder()
.complianceStatus("fake-compliance")
.build();
Throwable t = assertThrows(AttestationException.class, ()-> validator.validate(aksPayload, PUBLIC_KEY));
assertEquals("Not run in Azure Compliance Utility VM", t.getMessage());
assertEquals(AttestationFailure.BAD_FORMAT, ((AttestationClientException)t).getAttestationFailure());
}

@Test
public void testValidationSuccess_AksWithAzureSignedKataccUvm() throws AttestationClientException {
var validator = new PolicyValidator(ATTESTATION_URL);
var aksPayload = generateBasicPayload()
.toBuilder()
.complianceStatus("azure-signed-katacc-uvm")
.azureProtocol(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL)
.build();
var enclaveId = validator.validate(aksPayload, PUBLIC_KEY);
assertEquals(CCE_POLICY_DIGEST, enclaveId);
}

@Test
public void testValidationFailure_AksWithOtherUvm() {
var validator = new PolicyValidator(ATTESTATION_URL);
var aksPayload = generateBasicPayload()
.toBuilder()
.complianceStatus("fake-compliance")
.azureProtocol(MaaTokenPayload.AZURE_CC_AKS_PROTOCOL)
.build();
Throwable t = assertThrows(AttestationException.class, ()-> validator.validate(aksPayload, PUBLIC_KEY));
assertEquals("Not run in Azure Compliance Utility VM", t.getMessage());
assertEquals(AttestationFailure.BAD_FORMAT, ((AttestationClientException)t).getAttestationFailure());
}

@Test
public void testValidationFailure_InvalidProtocol() {
var validator = new PolicyValidator(ATTESTATION_URL);
var aksPayload = generateBasicPayload()
.toBuilder()
.azureProtocol("fake-protocol")
.build();
Throwable t = assertThrows(AttestationException.class, ()-> validator.validate(aksPayload, PUBLIC_KEY));
assertEquals("Azure protocol: fake-protocol not supported", t.getMessage());
assertEquals(AttestationFailure.INVALID_PROTOCOL, ((AttestationClientException)t).getAttestationFailure());
}
}
Loading
Loading