Skip to content
82 changes: 78 additions & 4 deletions src/code/ContainerRegistryServerAPICalls.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ internal class ContainerRegistryServerAPICalls : ServerApiCall
private static readonly FindResults emptyResponseResults = new FindResults(stringResponse: Utils.EmptyStrArray, hashtableResponse: emptyHashResponses, responseType: containerRegistryFindResponseType);

const string containerRegistryRefreshTokenTemplate = "grant_type=access_token&service={0}&tenant={1}&access_token={2}"; // 0 - registry, 1 - tenant, 2 - access token
const string containerRegistryAccessTokenTemplate = "grant_type=refresh_token&service={0}&scope=repository:*:*&refresh_token={1}"; // 0 - registry, 1 - refresh token
const string containerRegistryAccessTokenTemplate = "grant_type=refresh_token&service={0}&scope=repository:*:*&scope=registry:catalog:*&refresh_token={1}"; // 0 - registry, 1 - refresh token
const string containerRegistryOAuthExchangeUrlTemplate = "https://{0}/oauth2/exchange"; // 0 - registry
const string containerRegistryOAuthTokenUrlTemplate = "https://{0}/oauth2/token"; // 0 - registry
const string containerRegistryManifestUrlTemplate = "https://{0}/v2/{1}/manifests/{2}"; // 0 - registry, 1 - repo(modulename), 2 - tag(version)
const string containerRegistryBlobDownloadUrlTemplate = "https://{0}/v2/{1}/blobs/{2}"; // 0 - registry, 1 - repo(modulename), 2 - layer digest
const string containerRegistryFindImageVersionUrlTemplate = "https://{0}/v2/{1}/tags/list"; // 0 - registry, 1 - repo(modulename)
const string containerRegistryStartUploadTemplate = "https://{0}/v2/{1}/blobs/uploads/"; // 0 - registry, 1 - packagename
const string containerRegistryEndUploadTemplate = "https://{0}{1}&digest=sha256:{2}"; // 0 - registry, 1 - location, 2 - digest
const string defaultScope = "&scope=repository:*:*&scope=registry:catalog:*";
const string containerRegistryRepositoryListTemplate = "https://{0}/v2/_catalog"; // 0 - registry

#endregion
Expand Down Expand Up @@ -392,12 +393,18 @@ internal string GetContainerRegistryAccessToken(out ErrorRecord errRecord)
}
else
{
bool isRepositoryUnauthenticated = IsContainerRegistryUnauthenticated(Repository.Uri.ToString(), out errRecord);
bool isRepositoryUnauthenticated = IsContainerRegistryUnauthenticated(Repository.Uri.ToString(), out errRecord, out accessToken);
if (errRecord != null)
{
return null;
}

if (!string.IsNullOrEmpty(accessToken))
{
_cmdletPassedIn.WriteVerbose("Anonymous access token retrieved.");
return accessToken;
}

if (!isRepositoryUnauthenticated)
{
accessToken = Utils.GetAzAccessToken();
Expand Down Expand Up @@ -437,15 +444,82 @@ internal string GetContainerRegistryAccessToken(out ErrorRecord errRecord)
/// <summary>
/// Checks if container registry repository is unauthenticated.
/// </summary>
internal bool IsContainerRegistryUnauthenticated(string containerRegistyUrl, out ErrorRecord errRecord)
internal bool IsContainerRegistryUnauthenticated(string containerRegistyUrl, out ErrorRecord errRecord, out string anonymousAccessToken)
{
_cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::IsContainerRegistryUnauthenticated()");
errRecord = null;
anonymousAccessToken = string.Empty;
string endpoint = $"{containerRegistyUrl}/v2/";
HttpResponseMessage response;
try
{
response = _sessionClient.SendAsync(new HttpRequestMessage(HttpMethod.Head, endpoint)).Result;

if (response.StatusCode == HttpStatusCode.Unauthorized)
{
// check if there is a auth challenge header
if (response.Headers.WwwAuthenticate.Count() > 0)
{
var authHeader = response.Headers.WwwAuthenticate.First();
if (authHeader.Scheme == "Bearer")
{
// check if there is a realm
if (authHeader.Parameter.Contains("realm"))
{
// get the realm
var realm = authHeader.Parameter.Split(',')?.Where(x => x.Contains("realm"))?.FirstOrDefault()?.Split('=')[1]?.Trim('"');
// get the service
var service = authHeader.Parameter.Split(',')?.Where(x => x.Contains("service"))?.FirstOrDefault()?.Split('=')[1]?.Trim('"');

if (string.IsNullOrEmpty(realm) || string.IsNullOrEmpty(service))
{
errRecord = new ErrorRecord(
new InvalidOperationException("Failed to get realm or service from the auth challenge header."),
"RegistryUnauthenticationCheckError",
ErrorCategory.InvalidResult,
this);

return false;
}

string content = "grant_type=access_token&service=" + service + defaultScope;
var contentHeaders = new Collection<KeyValuePair<string, string>> { new KeyValuePair<string, string>("Content-Type", "application/x-www-form-urlencoded") };

// get the anonymous access token
var url = $"{realm}?service={service}{defaultScope}";

// we dont check the errorrecord here because we want to return false if we get a 401 and not throw an error
var results = GetHttpResponseJObjectUsingContentHeaders(url, HttpMethod.Get, content, contentHeaders, out _);

if (results == null)
{
_cmdletPassedIn.WriteDebug("Failed to get access token from the realm. results is null.");
return false;
}

if (results["access_token"] == null)
{
_cmdletPassedIn.WriteDebug($"Failed to get access token from the realm. access_token is null. results: {results}");
return false;
}

anonymousAccessToken = results["access_token"].ToString();
_cmdletPassedIn.WriteDebug("Anonymous access token retrieved");
return true;
}
}
}
}
}
catch (HttpRequestException hre)
{
errRecord = new ErrorRecord(
hre,
"RegistryAnonymousAcquireError",
ErrorCategory.ConnectionError,
this);

return false;
}
catch (Exception e)
{
Expand Down Expand Up @@ -1756,7 +1830,7 @@ private FindResults FindPackages(string packageName, bool includePrerelease, out
}

// This remove the 'psresource/' prefix from the repository name for comparison with wildcard.
string moduleName = repositoryName.Substring(11);
string moduleName = repositoryName.StartsWith("psresource/") ? repositoryName.Substring(11) : repositoryName;

WildcardPattern wildcardPattern = new WildcardPattern(packageName, WildcardOptions.IgnoreCase);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,11 @@ Describe 'Test HTTP Find-PSResource for ACR Server Protocol' -tags 'CI' {
$err[0].FullyQualifiedErrorId | Should -BeExactly "FindCommandOrDscResourceFailure,Microsoft.PowerShell.PSResourceGet.Cmdlets.FindPSResource"
}

It "Should not find all resources given Name '*'" {
It "Should find all resources given Name '*'" {
# FindAll()
$res = Find-PSResource -Name "*" -Repository $ACRRepoName -ErrorVariable err -ErrorAction SilentlyContinue
$res | Should -BeNullOrEmpty
$err.Count | Should -BeGreaterThan 0
$err[0].FullyQualifiedErrorId | Should -BeExactly "FindAllFailure,Microsoft.PowerShell.PSResourceGet.Cmdlets.FindPSResource"
$res | Should -Not -BeNullOrEmpty
$res.Count | Should -BeGreaterThan 0
}

It "Should find script given Name" {
Expand Down Expand Up @@ -268,3 +267,32 @@ Describe 'Test Find-PSResource for MAR Repository' -tags 'CI' {
$res.Count | Should -BeGreaterThan 1
}
}

# Skip this test fo
Describe 'Test Find-PSResource for unauthenticated ACR repository' -tags 'CI' {
BeforeAll {
$skipOnWinPS = $PSVersionTable.PSVersion.Major -eq 5

if (-not $skipOnWinPS) {
Register-PSResourceRepository -Name "Unauthenticated" -Uri "https://psresourcegetnoauth.azurecr.io/" -ApiVersion "ContainerRegistry"
}
}

AfterAll {
if (-not $skipOnWinPS) {
Unregister-PSResourceRepository -Name "Unauthenticated"
}
}

It "Should find resource given specific Name, Version null" {

if ($skipOnWinPS) {
Set-ItResult -Pending -Because "Skipping test on Windows PowerShell"
return
}

$res = Find-PSResource -Name "hello-world" -Repository "Unauthenticated"
$res.Name | Should -Be "hello-world"
$res.Version | Should -Be "5.0.0"
}
}