diff --git a/src/code/ContainerRegistryServerAPICalls.cs b/src/code/ContainerRegistryServerAPICalls.cs index 973525daa..785c7aeae 100644 --- a/src/code/ContainerRegistryServerAPICalls.cs +++ b/src/code/ContainerRegistryServerAPICalls.cs @@ -38,7 +38,7 @@ internal class ContainerRegistryServerAPICalls : ServerApiCall private static readonly FindResults emptyResponseResults = new FindResults(stringResponse: Utils.EmptyStrArray, hashtableResponse: emptyHashResponses, responseType: containerRegistryFindResponseType); const string containerRegistryRefreshTokenTemplate = "grant_type=access_token&service={0}&tenant={1}&access_token={2}"; // 0 - registry, 1 - tenant, 2 - access token - const string containerRegistryAccessTokenTemplate = "grant_type=refresh_token&service={0}&scope=repository:*:*&refresh_token={1}"; // 0 - registry, 1 - refresh token + const string containerRegistryAccessTokenTemplate = "grant_type=refresh_token&service={0}&scope=repository:*:*&scope=registry:catalog:*&refresh_token={1}"; // 0 - registry, 1 - refresh token const string containerRegistryOAuthExchangeUrlTemplate = "https://{0}/oauth2/exchange"; // 0 - registry const string containerRegistryOAuthTokenUrlTemplate = "https://{0}/oauth2/token"; // 0 - registry const string containerRegistryManifestUrlTemplate = "https://{0}/v2/{1}/manifests/{2}"; // 0 - registry, 1 - repo(modulename), 2 - tag(version) @@ -46,6 +46,7 @@ internal class ContainerRegistryServerAPICalls : ServerApiCall const string containerRegistryFindImageVersionUrlTemplate = "https://{0}/v2/{1}/tags/list"; // 0 - registry, 1 - repo(modulename) const string containerRegistryStartUploadTemplate = "https://{0}/v2/{1}/blobs/uploads/"; // 0 - registry, 1 - packagename const string containerRegistryEndUploadTemplate = "https://{0}{1}&digest=sha256:{2}"; // 0 - registry, 1 - location, 2 - digest + const string defaultScope = "&scope=repository:*:*&scope=registry:catalog:*"; const string containerRegistryRepositoryListTemplate = "https://{0}/v2/_catalog"; // 0 - registry #endregion @@ -392,12 +393,18 @@ internal string GetContainerRegistryAccessToken(out ErrorRecord errRecord) } else { - bool isRepositoryUnauthenticated = IsContainerRegistryUnauthenticated(Repository.Uri.ToString(), out errRecord); + bool isRepositoryUnauthenticated = IsContainerRegistryUnauthenticated(Repository.Uri.ToString(), out errRecord, out accessToken); if (errRecord != null) { return null; } + if (!string.IsNullOrEmpty(accessToken)) + { + _cmdletPassedIn.WriteVerbose("Anonymous access token retrieved."); + return accessToken; + } + if (!isRepositoryUnauthenticated) { accessToken = Utils.GetAzAccessToken(); @@ -437,15 +444,82 @@ internal string GetContainerRegistryAccessToken(out ErrorRecord errRecord) /// /// Checks if container registry repository is unauthenticated. /// - internal bool IsContainerRegistryUnauthenticated(string containerRegistyUrl, out ErrorRecord errRecord) + internal bool IsContainerRegistryUnauthenticated(string containerRegistyUrl, out ErrorRecord errRecord, out string anonymousAccessToken) { _cmdletPassedIn.WriteDebug("In ContainerRegistryServerAPICalls::IsContainerRegistryUnauthenticated()"); errRecord = null; + anonymousAccessToken = string.Empty; string endpoint = $"{containerRegistyUrl}/v2/"; HttpResponseMessage response; try { response = _sessionClient.SendAsync(new HttpRequestMessage(HttpMethod.Head, endpoint)).Result; + + if (response.StatusCode == HttpStatusCode.Unauthorized) + { + // check if there is a auth challenge header + if (response.Headers.WwwAuthenticate.Count() > 0) + { + var authHeader = response.Headers.WwwAuthenticate.First(); + if (authHeader.Scheme == "Bearer") + { + // check if there is a realm + if (authHeader.Parameter.Contains("realm")) + { + // get the realm + var realm = authHeader.Parameter.Split(',')?.Where(x => x.Contains("realm"))?.FirstOrDefault()?.Split('=')[1]?.Trim('"'); + // get the service + var service = authHeader.Parameter.Split(',')?.Where(x => x.Contains("service"))?.FirstOrDefault()?.Split('=')[1]?.Trim('"'); + + if (string.IsNullOrEmpty(realm) || string.IsNullOrEmpty(service)) + { + errRecord = new ErrorRecord( + new InvalidOperationException("Failed to get realm or service from the auth challenge header."), + "RegistryUnauthenticationCheckError", + ErrorCategory.InvalidResult, + this); + + return false; + } + + string content = "grant_type=access_token&service=" + service + defaultScope; + var contentHeaders = new Collection> { new KeyValuePair("Content-Type", "application/x-www-form-urlencoded") }; + + // get the anonymous access token + var url = $"{realm}?service={service}{defaultScope}"; + + // we dont check the errorrecord here because we want to return false if we get a 401 and not throw an error + var results = GetHttpResponseJObjectUsingContentHeaders(url, HttpMethod.Get, content, contentHeaders, out _); + + if (results == null) + { + _cmdletPassedIn.WriteDebug("Failed to get access token from the realm. results is null."); + return false; + } + + if (results["access_token"] == null) + { + _cmdletPassedIn.WriteDebug($"Failed to get access token from the realm. access_token is null. results: {results}"); + return false; + } + + anonymousAccessToken = results["access_token"].ToString(); + _cmdletPassedIn.WriteDebug("Anonymous access token retrieved"); + return true; + } + } + } + } + } + catch (HttpRequestException hre) + { + errRecord = new ErrorRecord( + hre, + "RegistryAnonymousAcquireError", + ErrorCategory.ConnectionError, + this); + + return false; } catch (Exception e) { @@ -1756,7 +1830,7 @@ private FindResults FindPackages(string packageName, bool includePrerelease, out } // This remove the 'psresource/' prefix from the repository name for comparison with wildcard. - string moduleName = repositoryName.Substring(11); + string moduleName = repositoryName.StartsWith("psresource/") ? repositoryName.Substring(11) : repositoryName; WildcardPattern wildcardPattern = new WildcardPattern(packageName, WildcardOptions.IgnoreCase); diff --git a/test/FindPSResourceTests/FindPSResourceContainerRegistryServer.Tests.ps1 b/test/FindPSResourceTests/FindPSResourceContainerRegistryServer.Tests.ps1 index 3b403b24d..8efa635a9 100644 --- a/test/FindPSResourceTests/FindPSResourceContainerRegistryServer.Tests.ps1 +++ b/test/FindPSResourceTests/FindPSResourceContainerRegistryServer.Tests.ps1 @@ -151,12 +151,11 @@ Describe 'Test HTTP Find-PSResource for ACR Server Protocol' -tags 'CI' { $err[0].FullyQualifiedErrorId | Should -BeExactly "FindCommandOrDscResourceFailure,Microsoft.PowerShell.PSResourceGet.Cmdlets.FindPSResource" } - It "Should not find all resources given Name '*'" { + It "Should find all resources given Name '*'" { # FindAll() $res = Find-PSResource -Name "*" -Repository $ACRRepoName -ErrorVariable err -ErrorAction SilentlyContinue - $res | Should -BeNullOrEmpty - $err.Count | Should -BeGreaterThan 0 - $err[0].FullyQualifiedErrorId | Should -BeExactly "FindAllFailure,Microsoft.PowerShell.PSResourceGet.Cmdlets.FindPSResource" + $res | Should -Not -BeNullOrEmpty + $res.Count | Should -BeGreaterThan 0 } It "Should find script given Name" { @@ -268,3 +267,32 @@ Describe 'Test Find-PSResource for MAR Repository' -tags 'CI' { $res.Count | Should -BeGreaterThan 1 } } + +# Skip this test fo +Describe 'Test Find-PSResource for unauthenticated ACR repository' -tags 'CI' { + BeforeAll { + $skipOnWinPS = $PSVersionTable.PSVersion.Major -eq 5 + + if (-not $skipOnWinPS) { + Register-PSResourceRepository -Name "Unauthenticated" -Uri "https://psresourcegetnoauth.azurecr.io/" -ApiVersion "ContainerRegistry" + } + } + + AfterAll { + if (-not $skipOnWinPS) { + Unregister-PSResourceRepository -Name "Unauthenticated" + } + } + + It "Should find resource given specific Name, Version null" { + + if ($skipOnWinPS) { + Set-ItResult -Pending -Because "Skipping test on Windows PowerShell" + return + } + + $res = Find-PSResource -Name "hello-world" -Repository "Unauthenticated" + $res.Name | Should -Be "hello-world" + $res.Version | Should -Be "5.0.0" + } +}