|
1 | | -# Copyright (c) Microsoft Corporation. |
2 | | -# Licensed under the MIT License. |
3 | | - |
4 | | -from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl |
5 | | -from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager |
6 | | -from azure.core.credentials import TokenCredential |
7 | | -import grpc |
8 | | - |
9 | | -class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): |
10 | | - """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, |
11 | | - StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an |
12 | | - interceptor to add additional headers to all calls as needed.""" |
13 | | - |
14 | | - def __init__(self, token_credential: TokenCredential, taskhub_name: str): |
15 | | - self._metadata = [("taskhub", taskhub_name)] |
16 | | - super().__init__(self._metadata) |
17 | | - |
18 | | - if token_credential is not None: |
19 | | - self._token_credential = token_credential |
20 | | - self._token_manager = AccessTokenManager(token_credential=self._token_credential) |
21 | | - token = self._token_manager.get_access_token() |
22 | | - self._metadata.append(("authorization", token)) |
23 | | - |
24 | | - def _intercept_call( |
25 | | - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: |
26 | | - """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC |
27 | | - call details.""" |
28 | | - # Refresh the auth token if it is present and needed |
29 | | - if self._metadata is not None: |
30 | | - for i, (key, _) in enumerate(self._metadata): |
31 | | - if key.lower() == "authorization": # Ensure case-insensitive comparison |
32 | | - new_token = self._token_manager.get_access_token() # Get the new token |
33 | | - self._metadata[i] = ("authorization", new_token) # Update the token |
34 | | - |
35 | | - return super()._intercept_call(client_call_details) |
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +import grpc |
| 5 | +from azure.core.credentials import TokenCredential |
| 6 | + |
| 7 | +from durabletask.azuremanaged.internal.access_token_manager import \ |
| 8 | + AccessTokenManager |
| 9 | +from durabletask.internal.grpc_interceptor import ( |
| 10 | + DefaultClientInterceptorImpl, _ClientCallDetails) |
| 11 | + |
| 12 | + |
| 13 | +class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): |
| 14 | + """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, |
| 15 | + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an |
| 16 | + interceptor to add additional headers to all calls as needed.""" |
| 17 | + |
| 18 | + def __init__(self, token_credential: TokenCredential, taskhub_name: str): |
| 19 | + self._metadata = [("taskhub", taskhub_name)] |
| 20 | + super().__init__(self._metadata) |
| 21 | + |
| 22 | + if token_credential is not None: |
| 23 | + self._token_credential = token_credential |
| 24 | + self._token_manager = AccessTokenManager(token_credential=self._token_credential) |
| 25 | + access_token = self._token_manager.get_access_token() |
| 26 | + if access_token is not None: |
| 27 | + self._metadata.append(("authorization", f"Bearer {access_token.token}")) |
| 28 | + |
| 29 | + def _intercept_call( |
| 30 | + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: |
| 31 | + """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC |
| 32 | + call details.""" |
| 33 | + # Refresh the auth token if it is present and needed |
| 34 | + if self._metadata is not None: |
| 35 | + for i, (key, _) in enumerate(self._metadata): |
| 36 | + if key.lower() == "authorization": # Ensure case-insensitive comparison |
| 37 | + new_token = self._token_manager.get_access_token() # Get the new token |
| 38 | + if new_token is not None: |
| 39 | + self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token |
| 40 | + |
| 41 | + return super()._intercept_call(client_call_details) |
0 commit comments