diff --git a/src/azul/drs.py b/src/azul/drs.py index 9c1b774550..d988e01112 100644 --- a/src/azul/drs.py +++ b/src/azul/drs.py @@ -34,7 +34,6 @@ from azul import ( R, cache, - cached_property, mutable_furl, ) from azul.http import ( @@ -199,14 +198,6 @@ def parse(cls, drs_uri: str) -> 'DRSURI': else: raise - @abstractmethod - def to_url(self, client: 'DRSClient', access_id: str | None = None) -> furl: - """ - Translate this DRS URI into a URL of the DRS REST API endpoint at which - the file identified by this DRS URI can be accessed. - """ - raise NotImplementedError - @attr.s(auto_attribs=True, kw_only=True, frozen=True, slots=True) class HostBasedDRSURI(DRSURI): @@ -228,8 +219,8 @@ def parse(cls, drs_uri: str) -> Self: assert len(path) == 1, R('Invalid path', drs_uri) return cls(server=parsed_uri.netloc, object_id=path[0]) - def to_url(self, client: 'DRSClient', access_id: str | None = None) -> furl: - path = drs_object_url_path(object_id=self.object_id, access_id=access_id) + def to_url(self) -> furl: + path = drs_object_url_path(object_id=self.object_id) return furl(scheme='https', netloc=self.server, path=path) @@ -281,23 +272,23 @@ def parse(cls, drs_uri: str) -> Self: def _decode(cls, s: str) -> str: return urllib.parse.unquote(s, errors='strict') - def to_url(self, client: 'DRSClient', access_id: str | None = None) -> furl: + def to_url(self, id_client: 'IdentifiersDotOrgClient') -> furl: if self.provider_code is not None: raise NotImplementedError( 'Resolving compact identifier-based DRS URIs with ' 'provider codes is currently not supported', self ) - url = client.id_client.resolve(self.namespace, self.accession) + url = id_client.resolve(self.namespace, self.accession) # The URL pattern registered at identifiers.org ought to replicate the - # DRS spec, but we have to re-create the path using the spec because the - # registered pattern does not support embedding the access ID. + # DRS spec. If the response to a request to the returned URL includes an + # access ID, another request must be made to the returned URL followed + # by the string `/access/` and the ID. assert str(url.path) == drs_object_url_path(object_id=self.accession), R( 'Format of resolved URL is incompatible with the DRS specification', url) - url.set(path=drs_object_url_path(object_id=self.accession, access_id=access_id)) return url -class IdentifiersDotOrgClient(HasCachedHttpClient): +class _BaseClient(HasCachedHttpClient): def _create_http_client(self) -> urllib3.request.RequestMethods: return Propagate429HttpClient( @@ -306,6 +297,26 @@ def _create_http_client(self) -> urllib3.request.RequestMethods: ) ) + +class DRSClient(metaclass=ABCMeta): + + @abstractmethod + def drs_object(self, drs_url: furl) -> 'DRSObject': + raise NotImplementedError + + +class UnauthenticatedDRSClient(DRSClient, _BaseClient): + """ + A generic DRS client that does not send authentication to the server. + """ + + def drs_object(self, drs_url: furl) -> 'DRSObject': + return DRSObject(url=drs_url, + http_client=self._http_client) + + +class IdentifiersDotOrgClient(_BaseClient): + def resolve(self, prefix: str, accession: str) -> mutable_furl: namespace_id = self._prefix_to_namespace(prefix) log.info('Resolved prefix %r to namespace ID %r', prefix, namespace_id) @@ -343,26 +354,20 @@ def _api_request(self, path: str, **args) -> MutableJSON: @attr.s(auto_attribs=True, kw_only=True, frozen=True) -class DRSClient: +class DRSObject: _http_client: urllib3.request.RequestMethods + _url: furl - @cached_property - def id_client(self) -> IdentifiersDotOrgClient: - return IdentifiersDotOrgClient() - - def get_object(self, - drs_uri: str, - access_method: AccessMethod = AccessMethod.https - ) -> Access: + def get(self, access_method: AccessMethod = AccessMethod.https) -> Access: """ Returns access to the content of the data object identified by the given URI. The scheme of the URL in the returned access object depends on the access method specified. """ - return self._get_object(drs_uri, access_method) + return self._get(access_method) - def _get_object(self, drs_uri: str, access_method: AccessMethod) -> Access: - url = DRSURI.parse(drs_uri).to_url(self) + def _get(self, access_method: AccessMethod) -> Access: + url = self._url while True: response = self._request(url) if response.status == 200: @@ -381,9 +386,9 @@ def _get_object(self, drs_uri: str, access_method: AccessMethod) -> Access: # https://github.com/ga4gh/data-repository-service-schemas/issues/361 assert access_method is AccessMethod.gs, R( 'Unexpected access method', access_method) - return self._get_object_access(drs_uri, access_id, AccessMethod.https) + return self._get_access(access_id, AccessMethod.https) elif access_id is not None: - return self._get_object_access(drs_uri, access_id, access_method) + return self._get_access(access_id, access_method) elif access_url is not None: scheme = furl(access_url['url']).scheme assert scheme == access_method.scheme, R( @@ -400,12 +405,9 @@ def _get_object(self, drs_uri: str, access_method: AccessMethod) -> Access: else: raise DRSStatusException(url, response) - def _get_object_access(self, - drs_uri: str, - access_id: str, - access_method: AccessMethod - ) -> Access: - url = DRSURI.parse(drs_uri).to_url(self, access_id) + def _get_access(self, access_id: str, access_method: AccessMethod) -> Access: + url = self._url.copy() + url.path.add(['access', access_id]) while True: response = self._request(url) if response.status == 200: diff --git a/src/azul/indexer/mirror_service.py b/src/azul/indexer/mirror_service.py index 1b7eacd6ae..09c83312b7 100644 --- a/src/azul/indexer/mirror_service.py +++ b/src/azul/indexer/mirror_service.py @@ -722,8 +722,8 @@ def _repository_url(self, file: File) -> furl: 'Only TDR catalogs are supported', self.catalog) assert file.drs_uri is not None, R( 'File cannot be downloaded', file) - drs = self._repository_plugin.drs_client(authentication=None) - access = drs.get_object(file.drs_uri, AccessMethod.gs) + object = self._repository_plugin.drs_object(file.drs_uri) + access = object.get(AccessMethod.gs) assert access.method is AccessMethod.https, access return furl(access.url) diff --git a/src/azul/plugins/__init__.py b/src/azul/plugins/__init__.py index 3fc3af0342..c68393f697 100644 --- a/src/azul/plugins/__init__.py +++ b/src/azul/plugins/__init__.py @@ -23,6 +23,9 @@ ) import attrs +from furl import ( + furl, +) from more_itertools import ( one, ) @@ -43,7 +46,12 @@ Digest, ) from azul.drs import ( - DRSClient, + CompactDRSURI, + DRSObject, + DRSURI, + HostBasedDRSURI, + IdentifiersDotOrgClient, + UnauthenticatedDRSClient, ) from azul.indexer import ( Bundle, @@ -830,16 +838,36 @@ def list_files(self, source: SOURCE_REF, prefix: str) -> list['File']: raise NotImplementedError - @abstractmethod - def drs_client(self, + def drs_object(self, + drs_uri: str, authentication: Authentication | None = None - ) -> DRSClient: + ) -> DRSObject: """ Returns a DRS client that uses the given authentication with requests to the DRS server. If a concrete subclass doesn't support authentication, it should assert that the argument is ``None``. """ - raise NotImplementedError + assert authentication is None, type(authentication) + drs_url = self._resolve_drs_uri(drs_uri) + return self._unauthenticated_drs.drs_object(drs_url) + + def _resolve_drs_uri(self, drs_uri: str) -> furl: + drs_uri = DRSURI.parse(drs_uri) + if isinstance(drs_uri, CompactDRSURI): + drs_url = drs_uri.to_url(self._identifiers_dot_org) + elif isinstance(drs_uri, HostBasedDRSURI): + drs_url = drs_uri.to_url() + else: + assert False + return drs_url + + @cached_property + def _unauthenticated_drs(self) -> UnauthenticatedDRSClient: + return UnauthenticatedDRSClient() + + @cached_property + def _identifiers_dot_org(self) -> IdentifiersDotOrgClient: + return IdentifiersDotOrgClient() @abstractmethod def file_download_class(self) -> type['RepositoryFileDownload']: diff --git a/src/azul/plugins/repository/canned/__init__.py b/src/azul/plugins/repository/canned/__init__.py index cb68337603..1db67a6947 100644 --- a/src/azul/plugins/repository/canned/__init__.py +++ b/src/azul/plugins/repository/canned/__init__.py @@ -27,9 +27,6 @@ from azul.auth import ( Authentication, ) -from azul.drs import ( - DRSClient, -) from azul.http import ( HasCachedHttpClient, ) @@ -276,12 +273,6 @@ def _direct_file_url(self, def file_download_class(self) -> type[RepositoryFileDownload]: return CannedFileDownload - def drs_client(self, - authentication: Authentication | None = None - ) -> DRSClient: - assert authentication is None, type(authentication) - return DRSClient(http_client=self._http_client) - def validate_version(self, version: str) -> None: parse_dcp2_version(version) diff --git a/src/azul/plugins/repository/dss/__init__.py b/src/azul/plugins/repository/dss/__init__.py index 13e8cd8277..fc9d7b924d 100644 --- a/src/azul/plugins/repository/dss/__init__.py +++ b/src/azul/plugins/repository/dss/__init__.py @@ -32,9 +32,6 @@ from azul.deployment import ( aws, ) -from azul.drs import ( - DRSClient, -) from azul.http import ( HasCachedHttpClient, ) @@ -198,12 +195,6 @@ def _direct_file_url(self, url.query.add(adict(version=file_version, replica=replica, token=token)) return str(url) - def drs_client(self, - authentication: Authentication | None = None - ) -> DRSClient: - assert authentication is None, type(authentication) - return DRSClient(http_client=self._http_client) - def file_download_class(self) -> type[RepositoryFileDownload]: return DSSFileDownload diff --git a/src/azul/plugins/repository/tdr.py b/src/azul/plugins/repository/tdr.py index 54d543ba92..4593dc3d20 100644 --- a/src/azul/plugins/repository/tdr.py +++ b/src/azul/plugins/repository/tdr.py @@ -23,6 +23,7 @@ from azul import ( cache_per_thread, + config, require, ) from azul.auth import ( @@ -35,7 +36,7 @@ ) from azul.drs import ( AccessMethod, - DRSClient, + DRSObject, ) from azul.indexer import ( Bundle, @@ -177,13 +178,6 @@ def _user_authenticated_tdr(cls, type(authentication)) return tdr - @classmethod - @cache_per_thread - def _drs_client(cls, - authentication: Authentication | None = None - ) -> DRSClient: - return cls._user_authenticated_tdr(authentication).drs_client() - def _lookup_source_id(self, spec: TDRSourceSpec) -> str: return self.tdr.lookup_source(spec) @@ -209,10 +203,19 @@ def _full_table_name(self, source: TDRSourceSpec, table_name: str) -> str: def _emulate_bundle(self, bundle_fqid: TDRBundleFQID) -> TDR_BUNDLE: raise NotImplementedError - def drs_client(self, + def drs_object(self, + drs_uri: str, authentication: Authentication | None = None - ) -> DRSClient: - return self._drs_client(authentication) + ) -> DRSObject: + drs_url = self._resolve_drs_uri(drs_uri) + tdr_url = config.tdr_service_url + # Authenticate only if the DRS server is TDR so that we don't leak user + # or service account tokens to untrusted servers. + if (drs_url.scheme, drs_url.host) == (tdr_url.scheme, tdr_url.host): + drs_client = self._user_authenticated_tdr(authentication) + else: + drs_client = self._unauthenticated_drs + return drs_client.drs_object(drs_url) def file_download_class(self) -> type[RepositoryFileDownload]: return TDRFileDownload @@ -270,9 +273,8 @@ def update(self, assert self.location is None, self assert self.retry_after is None, self else: - drs_client = plugin.drs_client(authentication) - access = drs_client.get_object(self.file.drs_uri, - access_method=AccessMethod.gs) + drs_client = plugin.drs_object(self.file.drs_uri, authentication) + access = drs_client.get(access_method=AccessMethod.gs) require(access.method is AccessMethod.https, access.method) require(access.headers is None, access.headers) signed_url = access.url diff --git a/src/azul/service/download_controller.py b/src/azul/service/download_controller.py index 5157bc2c22..45d8b67558 100644 --- a/src/azul/service/download_controller.py +++ b/src/azul/service/download_controller.py @@ -16,6 +16,7 @@ BadRequestError, NotFoundError, TooManyRequestsError, + UnauthorizedError, ) from azul import ( @@ -33,6 +34,9 @@ from azul.collections import ( adict, ) +from azul.drs import ( + DRSStatusException, +) from azul.http import ( LimitedTimeoutException, TooManyRequestsException, @@ -168,6 +172,12 @@ def download_file(self, raise ServiceUnavailableError(*e.args) except TooManyRequestsException as e: raise TooManyRequestsError(*e.args) + except DRSStatusException as e: + msg, status, data = e.args + if status == UnauthorizedError.STATUS_CODE: + raise UnauthorizedError(msg) + else: + raise if download.retry_after is not None: retry_after = min(download.retry_after, int(1.3 ** request_index)) if wait is not None: diff --git a/src/azul/terra.py b/src/azul/terra.py index 6c859ab70c..ae2f2d8380 100644 --- a/src/azul/terra.py +++ b/src/azul/terra.py @@ -77,6 +77,7 @@ ) from azul.drs import ( DRSClient, + DRSObject, ) from azul.http import ( LimitedRetryHttpClient, @@ -368,7 +369,7 @@ def _insufficient_access(self, resource: str) -> Exception: return self.credentials_provider.insufficient_access(resource) -class TDRClient(SAMClient): +class TDRClient(SAMClient, DRSClient): """ A client for the Broad Institute's Terra Data Repository aka "Jade". """ @@ -641,8 +642,8 @@ def for_registered_user(cls, authentication: OAuth2) -> Self: else: return self - def drs_client(self) -> DRSClient: - return DRSClient(http_client=self._http_client) + def drs_object(self, drs_url: furl) -> DRSObject: + return DRSObject(url=drs_url, http_client=self._http_client) def get_duos(self, source: TDRSourceRef diff --git a/test/integration_test.py b/test/integration_test.py index a6e7b1fc21..3ee57296de 100644 --- a/test/integration_test.py +++ b/test/integration_test.py @@ -1145,24 +1145,31 @@ def _test_repository_files(self, catalog: CatalogName): self.assertIsNone(file_url, inner_file) self.assertEqual('lungmap', config.catalogs[catalog].atlas, inner_file) - def _test_file_download(self, source: SourceSpec, file: JSON) -> mutable_furl: + def _test_file_download(self, source: SourceSpec, file: JSON) -> mutable_furl | None: file_url = furl(file['azul_url']) # FIXME: Use _check_endpoint() instead # https://github.com/DataBiosphere/azul/issues/7373 self.assertEqual(file_url.path.segments[0], 'repository') file_url.path.segments.insert(0, 'fetch') response = self._get_url_unchecked(GET, file_url) - self.assertEqual(200, response.status) - response = json.loads(response.data) - while response['Status'] != 302: - self.assertEqual(301, response['Status']) + if response.status == 401: + msg = json.loads(response.data)['Message'] + prefix = 'Unexpected response from ' + self.assertEqual(prefix, msg[:len(prefix)]) + self.assertNotIn(str(config.tdr_service_url), msg) + return None + else: + self.assertEqual(200, response.status) + response = json.loads(response.data) + while response['Status'] != 302: + self.assertEqual(301, response['Status']) + self.assertNotIn('Retry-After', response) + response = self._get_url_json(GET, furl(response['Location'])) self.assertNotIn('Retry-After', response) - response = self._get_url_json(GET, furl(response['Location'])) - self.assertNotIn('Retry-After', response) - final_file_url = furl(response['Location']) - response = self._get_url(GET, final_file_url, stream=True) - self._validate_file_response(response, source, file) - return final_file_url + final_file_url = furl(response['Location']) + response = self._get_url(GET, final_file_url, stream=True) + self._validate_file_response(response, source, file) + return final_file_url def _file_ext(self, file: JSON) -> str: # We believe that the file extension is a more reliable indicator than @@ -1209,13 +1216,13 @@ def _test_drs(self, file: JSON ) -> None: repository_plugin = self.azul_client.repository_plugin(catalog) - drs = repository_plugin.drs_client() file_uuid = lookup(file, 'document_id', 'uuid') + drs_uri = f'drs://{config.api_lambda_domain("service")}/{file_uuid}' + drs_object = repository_plugin.drs_object(drs_uri) for access_method in AccessMethod: with self.subTest('drs', catalog=catalog, access_method=AccessMethod.https): log.info('Resolving file %r with DRS using %r', file_uuid, access_method) - drs_uri = f'drs://{config.api_lambda_domain("service")}/{file_uuid}' - access = drs.get_object(drs_uri, access_method=access_method) + access = drs_object.get(access_method) self.assertIsNone(access.headers) if access.method is AccessMethod.https: response = self._get_url(GET, furl(access.url), stream=True) @@ -1795,6 +1802,7 @@ def _delete(): aws.mirror_bucket, '_it', 'file', f'{digest.value}.{digest.type}', ]) actual_url = self._test_file_download(source.spec, file_response) + self.assertIsNotNone(actual_url) actual_url.set(args=None) self.assertEqual(expected_url, actual_url) _delete() diff --git a/test/service/test_repository_files.py b/test/service/test_repository_files.py index ab8eba0640..2421de8151 100644 --- a/test/service/test_repository_files.py +++ b/test/service/test_repository_files.py @@ -43,7 +43,7 @@ from azul.drs import ( Access, AccessMethod, - DRSClient, + DRSObject, ) from azul.http import ( http_client, @@ -168,10 +168,8 @@ def test_repository_files(self, mock_get_cached_sources): 'X-Goog-SignedHeaders': 'host', 'X-Goog-Signature': 'SOMESIGNATURE', }) - with mock.patch.object(DRSClient, - 'get_object', - return_value=Access(method=AccessMethod.https, - url=str(pre_signed_gs))): + access = Access(method=AccessMethod.https, url=str(pre_signed_gs)) + with mock.patch.object(DRSObject, 'get', return_value=access): response = client.request('GET', str(azul_url), redirect=False) self.assertEqual(200 if fetch else 302, response.status) if fetch: