Skip to content
This repository was archived by the owner on Aug 29, 2025. It is now read-only.

Commit 79db8c9

Browse files
calebkiagebaywet
andauthored
fix: only fetch valid X.509 certificates from the certificate store (#425)
* fix: only fetch valid X.509 certificates from the certificate store perf: remove linq code use to improve code generation * fix: check certificate validity period is valid * chore: add certificate selection tests. chore: remove redundant certificate date checks * Update src/Microsoft.Graph.Cli.Core/Authentication/ClientCertificateCredentialFactory.cs Co-authored-by: Vincent Biret <[email protected]> --------- Co-authored-by: Vincent Biret <[email protected]>
1 parent ce464be commit 79db8c9

File tree

3 files changed

+117
-27
lines changed

3 files changed

+117
-27
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using System;
2+
using System.Security.Cryptography;
3+
using System.Security.Cryptography.X509Certificates;
4+
using Microsoft.Graph.Cli.Core.Authentication;
5+
using Xunit;
6+
7+
namespace Microsoft.Graph.Cli.Core.Tests.Authentication;
8+
9+
public class ClientCertificateCredentialFactoryTests
10+
{
11+
[Fact]
12+
public void ReturnsNullWhenStoreIsEmpty()
13+
{
14+
var store = new X509Certificate2Collection();
15+
16+
var result = ClientCertificateCredentialFactory.FindLatestByValidity(store);
17+
18+
Assert.Null(result);
19+
}
20+
21+
[Fact]
22+
public void ReturnsLatestValidCertificate()
23+
{
24+
var store = new X509Certificate2Collection
25+
{
26+
GenerateSelfSignedCertificate("1", new DateTimeOffset(2020, 1, 2, 0, 0, 0, TimeSpan.Zero),
27+
new DateTimeOffset(2027, 1, 1, 0, 0, 0, TimeSpan.Zero)),
28+
GenerateSelfSignedCertificate("2", new DateTimeOffset(2020, 1, 1, 0, 0, 0, TimeSpan.Zero),
29+
new DateTimeOffset(2027, 1, 1, 0, 0, 0, TimeSpan.Zero))
30+
};
31+
32+
var result = ClientCertificateCredentialFactory.FindLatestByValidity(store);
33+
34+
Assert.NotNull(result);
35+
Assert.Equal("CN=1", result.SubjectName.Name);
36+
}
37+
38+
private static X509Certificate2 GenerateSelfSignedCertificate(string subjectName, DateTimeOffset notBefore,
39+
DateTimeOffset notAfter)
40+
{
41+
if (notAfter < notBefore)
42+
{
43+
throw new ArgumentException("notAfter must be after notBefore");
44+
}
45+
46+
const string secp256R1Oid = "1.2.840.10045.3.1.7";
47+
var ecdsa = ECDsa.Create(ECCurve.CreateFromValue(secp256R1Oid));
48+
var certRequest = new CertificateRequest($"CN={subjectName}", ecdsa, HashAlgorithmName.SHA256);
49+
var generatedCert = certRequest.CreateSelfSigned(notBefore, notAfter);
50+
return generatedCert;
51+
}
52+
}

src/Microsoft.Graph.Cli.Core/Authentication/ClientCertificateCredentialFactory.cs

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
using System;
2-
using System.Linq;
32
using System.Security;
43
using System.Security.Cryptography;
54
using System.Security.Cryptography.X509Certificates;
65
using Azure.Identity;
7-
using Microsoft.Graph.Cli.Core.Utils;
86

97
namespace Microsoft.Graph.Cli.Core.Authentication;
108

@@ -23,7 +21,8 @@ public static class ClientCertificateCredentialFactory
2321
/// <param name="authorityHost">The entra authentication endpoint (to use with national clouds)</param>
2422
/// <returns>A ClientCertificateCredential</returns>
2523
/// <exception cref="ArgumentNullException">When a null url is provided for the authority host.</exception>
26-
public static ClientCertificateCredential GetClientCertificateCredential(string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint, Uri authorityHost)
24+
public static ClientCertificateCredential GetClientCertificateCredential(string? tenantId, string? clientId,
25+
string? certificateName, string? certificateThumbPrint, Uri authorityHost)
2726
{
2827
if (string.IsNullOrWhiteSpace(certificateName) && string.IsNullOrWhiteSpace(certificateThumbPrint))
2928
{
@@ -46,11 +45,13 @@ public static ClientCertificateCredential GetClientCertificateCredential(string?
4645

4746
X509Certificate2? certificate;
4847

49-
if (!string.IsNullOrWhiteSpace(certificateName) && TryGetCertificateFromStore(certificateName, isThumbPrint: false, out certificate))
48+
if (!string.IsNullOrWhiteSpace(certificateName) &&
49+
TryGetCertificateFromStore(certificateName, isThumbPrint: false, out certificate))
5050
{
5151
return new ClientCertificateCredential(tenantId, clientId, certificate, credOptions);
5252
}
53-
else if (!string.IsNullOrWhiteSpace(certificateThumbPrint) && TryGetCertificateFromStore(certificateThumbPrint, isThumbPrint: true, out certificate))
53+
else if (!string.IsNullOrWhiteSpace(certificateThumbPrint) &&
54+
TryGetCertificateFromStore(certificateThumbPrint, isThumbPrint: true, out certificate))
5455
{
5556
return new ClientCertificateCredential(tenantId, clientId, certificate, credOptions);
5657
}
@@ -65,7 +66,8 @@ public static ClientCertificateCredential GetClientCertificateCredential(string?
6566
/// <param name="isThumbPrint">If true, try to find the certificate by the thumb print.</param>
6667
/// <param name="certificate">A matching unexpired certificate from the store.</param>
6768
/// <returns>Returns true if the certificate was fetched successfully.</returns>
68-
internal static bool TryGetCertificateFromStore(string certificateNameOrThumbPrint, bool isThumbPrint, out X509Certificate2? certificate)
69+
internal static bool TryGetCertificateFromStore(string certificateNameOrThumbPrint, bool isThumbPrint,
70+
out X509Certificate2? certificate)
6971
{
7072
bool result = false;
7173
certificate = null;
@@ -74,19 +76,18 @@ internal static bool TryGetCertificateFromStore(string certificateNameOrThumbPri
7476
try
7577
{
7678
store.Open(OpenFlags.ReadOnly);
77-
78-
// If using a certificate with a trusted root you do not need to FindByTimeValid, instead:
79-
// currentCerts.Find(X509FindType.FindBySubjectDistinguishedName, certName, true);
80-
X509Certificate2Collection signingCerts = store.Certificates.Find(X509FindType.FindByTimeValid, DateTime.Now, false)
81-
.Find(isThumbPrint ? X509FindType.FindByThumbprint : X509FindType.FindBySubjectDistinguishedName, certificateNameOrThumbPrint, false);
79+
X509Certificate2Collection signingCerts = store.Certificates
80+
.Find(X509FindType.FindByTimeValid, DateTime.Now, false)
81+
.Find(isThumbPrint ? X509FindType.FindByThumbprint : X509FindType.FindBySubjectDistinguishedName,
82+
certificateNameOrThumbPrint, false);
8283
if (signingCerts.Count == 0)
8384
{
8485
result = false;
8586
}
8687
else
8788
{
88-
// Return the first certificate in the collection, has the right name and is current.
89-
certificate = signingCerts.OrderByDescending(static c => c.NotBefore).FirstOrDefault();
89+
certificate = FindLatestByValidity(signingCerts);
90+
9091
result = true;
9192
}
9293
}
@@ -98,7 +99,7 @@ internal static bool TryGetCertificateFromStore(string certificateNameOrThumbPri
9899
}
99100
catch (SecurityException)
100101
{
101-
// Isufficient permissions to read the store
102+
// Insufficient permissions to read the store
102103
result = false;
103104
}
104105
catch (ArgumentException)
@@ -113,12 +114,30 @@ internal static bool TryGetCertificateFromStore(string certificateNameOrThumbPri
113114
certificate.Dispose();
114115
certificate = null;
115116
}
117+
116118
store.Close();
117119
}
118120

119121
return result;
120122
}
121123

124+
internal static X509Certificate2? FindLatestByValidity(X509Certificate2Collection signingCerts)
125+
{
126+
X509Certificate2? certificate = null;
127+
// Return the first certificate in the collection, has the right name and is current.
128+
foreach (var cert in signingCerts)
129+
{
130+
// Use this certificate if it became valid after the currently selected one.
131+
if (certificate is null || cert.NotBefore > certificate.NotBefore)
132+
{
133+
certificate?.Dispose();
134+
certificate = cert;
135+
}
136+
}
137+
138+
return certificate;
139+
}
140+
122141
/// <summary>
123142
/// Opens a certificate file.
124143
/// </summary>

src/Microsoft.Graph.Cli.Core/Http/LoggingHandler.cs

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using System;
22
using System.Collections.Generic;
3-
using System.Linq;
43
using System.Net.Http;
54
using System.Net.Http.Headers;
5+
using System.Text;
66
using System.Threading;
77
using System.Threading.Tasks;
88
using Microsoft.Extensions.Logging;
@@ -47,29 +47,48 @@ protected override async Task<HttpResponseMessage> SendAsync(
4747

4848
private static string HeadersToString(in HttpHeaders headers, in HttpContentHeaders? contentHeaders)
4949
{
50-
if (!headers.Any() && contentHeaders?.Any() == false) return string.Empty;
51-
static string selector(KeyValuePair<string, IEnumerable<string>> h)
50+
var headersEnumerator = headers.GetEnumerator();
51+
var contentHeadersEnumerator = contentHeaders?.GetEnumerator();
52+
if (!headersEnumerator.MoveNext() && contentHeadersEnumerator?.MoveNext() == false) return string.Empty;
53+
headersEnumerator.Dispose();
54+
contentHeadersEnumerator?.Dispose();
55+
static void StringifyHeader(string name, IEnumerable<string> values, in StringBuilder sb)
5256
{
53-
var value = string.Join(",", h.Value);
54-
if (h.Key.Contains("Authorization", StringComparison.OrdinalIgnoreCase))
57+
sb.Append(name);
58+
sb.Append(':');
59+
sb.Append(' ');
60+
if (name.Contains("Authorization", StringComparison.OrdinalIgnoreCase))
5561
{
56-
value = "[PROTECTED]";
62+
sb.Append("[PROTECTED]");
5763
}
58-
return $"{h.Key}: {value}\n";
59-
};
64+
else
65+
{
66+
foreach (var value in values)
67+
{
68+
sb.Append(value);
69+
sb.Append(',');
70+
}
6071

61-
static string aggregator(string a, string b)
72+
sb.Remove(sb.Length - 1, 1);
73+
}
74+
}
75+
static void JoinHeaders(IEnumerable<KeyValuePair<string, IEnumerable<string>>> headers, in StringBuilder sb)
6276
{
63-
return string.Join(string.Empty, a, b);
77+
foreach (var header in headers)
78+
{
79+
StringifyHeader(header.Key, header.Value, sb);
80+
sb.Append('\n');
81+
}
6482
}
6583

66-
var h = headers.Select(selector).Aggregate("", aggregator);
84+
var sb = new StringBuilder(200);
85+
JoinHeaders(headers, sb);
6786
if (contentHeaders != null)
6887
{
69-
h += contentHeaders.Select(selector).Aggregate("", aggregator);
88+
JoinHeaders(contentHeaders, sb);
7089
}
7190

72-
return h;
91+
return sb.ToString();
7392
}
7493

7594
/// <summary>

0 commit comments

Comments
 (0)