Skip to content

Commit 99f62d7

Browse files
committed
Moving some files around to remove dependencies
Signed-off-by: Ryan Lettieri <[email protected]>
1 parent ed733ea commit 99f62d7

File tree

14 files changed

+297
-228
lines changed

14 files changed

+297
-228
lines changed

durabletask-azuremanaged/__init__.py

Whitespace-only changes.

durabletask-azuremanaged/durabletask/azuremanaged/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,59 @@
1-
# Copyright (c) Microsoft Corporation.
2-
# Licensed under the MIT License.
3-
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
4-
from datetime import datetime, timedelta, timezone
5-
from typing import Optional
6-
import durabletask.internal.shared as shared
7-
8-
# By default, when there's 10minutes left before the token expires, refresh the token
9-
class AccessTokenManager:
10-
def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, str]]] = None):
11-
self.scope = "https://durabletask.io/.default"
12-
self.refresh_buffer = refresh_buffer
13-
self._use_managed_identity = False
14-
self._metadata = metadata
15-
self._client_id = None
16-
self._logger = shared.get_logger("token_manager")
17-
18-
if metadata: # Ensure metadata is not None
19-
for key, value in metadata:
20-
if key == "use_managed_identity":
21-
self._use_managed_identity = value.lower() == "true" # Properly convert string to bool
22-
elif key == "client_id":
23-
self._client_id = value # Directly assign string
24-
25-
# Choose the appropriate credential based on use_managed_identity
26-
if self._use_managed_identity:
27-
if not self._client_id:
28-
self._logger.debug("Using System Assigned Managed Identity for authentication.")
29-
self.credential = ManagedIdentityCredential()
30-
else:
31-
self._logger.debug("Using User Assigned Managed Identity for authentication.")
32-
self.credential = ManagedIdentityCredential(client_id=self._client_id)
33-
else:
34-
self.credential = DefaultAzureCredential()
35-
self._logger.debug("Using Default Azure Credentials for authentication.")
36-
37-
self.token = None
38-
self.expiry_time = None
39-
40-
def get_access_token(self) -> str:
41-
if self.token is None or self.is_token_expired():
42-
self.refresh_token()
43-
return self.token
44-
45-
# Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds.
46-
# For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
47-
# We will grab a new token when there're 30minutes left on the lifespan of the token
48-
def is_token_expired(self) -> bool:
49-
if self.expiry_time is None:
50-
return True
51-
return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_buffer))
52-
53-
def refresh_token(self):
54-
new_token = self.credential.get_token(self.scope)
55-
self.token = f"Bearer {new_token.token}"
56-
57-
# Convert UNIX timestamp to timezone-aware datetime
58-
self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc)
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
4+
from datetime import datetime, timedelta, timezone
5+
from typing import Optional
6+
import durabletask.internal.shared as shared
7+
8+
# By default, when there's 10minutes left before the token expires, refresh the token
9+
class AccessTokenManager:
10+
def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, str]]] = None):
11+
self.scope = "https://durabletask.io/.default"
12+
self.refresh_buffer = refresh_buffer
13+
self._use_managed_identity = False
14+
self._metadata = metadata
15+
self._client_id = None
16+
self._logger = shared.get_logger("token_manager")
17+
18+
if metadata: # Ensure metadata is not None
19+
for key, value in metadata:
20+
if key == "use_managed_identity":
21+
self._use_managed_identity = value.lower() == "true" # Properly convert string to bool
22+
elif key == "client_id":
23+
self._client_id = value # Directly assign string
24+
25+
# Choose the appropriate credential based on use_managed_identity
26+
if self._use_managed_identity:
27+
if not self._client_id:
28+
self._logger.debug("Using System Assigned Managed Identity for authentication.")
29+
self.credential = ManagedIdentityCredential()
30+
else:
31+
self._logger.debug("Using User Assigned Managed Identity for authentication.")
32+
self.credential = ManagedIdentityCredential(client_id=self._client_id)
33+
else:
34+
self.credential = DefaultAzureCredential()
35+
self._logger.debug("Using Default Azure Credentials for authentication.")
36+
37+
self.token = None
38+
self.expiry_time = None
39+
40+
def get_access_token(self) -> str:
41+
if self.token is None or self.is_token_expired():
42+
self.refresh_token()
43+
return self.token
44+
45+
# Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds.
46+
# For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
47+
# We will grab a new token when there're 30minutes left on the lifespan of the token
48+
def is_token_expired(self) -> bool:
49+
if self.expiry_time is None:
50+
return True
51+
return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_buffer))
52+
53+
def refresh_token(self):
54+
new_token = self.credential.get_token(self.scope)
55+
self.token = f"Bearer {new_token.token}"
56+
57+
# Convert UNIX timestamp to timezone-aware datetime
58+
self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc)
5959
self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
1-
# Copyright (c) Microsoft Corporation.
2-
# Licensed under the MIT License.
3-
4-
from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl
5-
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
6-
7-
import grpc
8-
9-
class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
10-
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
11-
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
12-
interceptor to add additional headers to all calls as needed."""
13-
14-
def __init__(self, metadata: list[tuple[str, str]]):
15-
super().__init__(metadata)
16-
self._token_manager = AccessTokenManager(metadata=self._metadata)
17-
18-
def _intercept_call(
19-
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
20-
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
21-
call details."""
22-
# Refresh the auth token if it is present and needed
23-
if self._metadata is not None:
24-
for i, (key, _) in enumerate(self._metadata):
25-
if key.lower() == "authorization": # Ensure case-insensitive comparison
26-
new_token = self._token_manager.get_access_token() # Get the new token
27-
self._metadata[i] = ("authorization", new_token) # Update the token
28-
29-
return super()._intercept_call(client_call_details)
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl
5+
from durabletask.azuremanaged.access_token_manager import AccessTokenManager
6+
7+
import grpc
8+
9+
class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
10+
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
11+
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
12+
interceptor to add additional headers to all calls as needed."""
13+
14+
def __init__(self, metadata: list[tuple[str, str]]):
15+
super().__init__(metadata)
16+
self._token_manager = AccessTokenManager(metadata=self._metadata)
17+
18+
def _intercept_call(
19+
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
20+
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
21+
call details."""
22+
# Refresh the auth token if it is present and needed
23+
if self._metadata is not None:
24+
for i, (key, _) in enumerate(self._metadata):
25+
if key.lower() == "authorization": # Ensure case-insensitive comparison
26+
new_token = self._token_manager.get_access_token() # Get the new token
27+
self._metadata[i] = ("authorization", new_token) # Update the token
28+
29+
return super()._intercept_call(client_call_details)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import threading
2+
import time
3+
import requests # You could use grpc or another library depending on your setup
4+
5+
class KeepAliveService:
6+
def __init__(self, interval: int = 60, endpoint: str = "https://sdktest1-fgcac9hja3f8.northcentralus.durabletask.io"):
7+
self.interval = interval # Time interval in seconds
8+
self.endpoint = endpoint # The endpoint for sending no-op requests
9+
self._keep_alive_thread = threading.Thread(target=self._send_noop_periodically)
10+
self._keep_alive_thread.daemon = True # Makes sure it ends when the main program ends
11+
self._keep_alive_thread.start()
12+
13+
def _send_noop_periodically(self):
14+
while True:
15+
try:
16+
# Send a simple GET or POST request to a "ping" or no-op endpoint
17+
response = requests.get(self.endpoint) # Replace with the appropriate method
18+
if response.status_code == 200:
19+
print("No-op request sent successfully.")
20+
else:
21+
print(f"No-op failed with status code {response.status_code}")
22+
except Exception as e:
23+
print(f"Error sending no-op: {e}")
24+
25+
time.sleep(self.interval) # Wait before sending another no-op
26+
27+
# Example Usage
28+
keep_alive_service = KeepAliveService(interval=60, endpoint="https://sdktest1-fgcac9hja3f8.northcentralus.durabletask.io")
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,64 @@
1-
# Copyright (c) Microsoft Corporation.
2-
# Licensed under the MIT License.
3-
4-
from typing import Optional
5-
from durabletask.client import TaskHubGrpcClient
6-
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
7-
from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
8-
9-
# Client class used for Durable Task Scheduler (DTS)
10-
class DurableTaskSchedulerClient(TaskHubGrpcClient):
11-
def __init__(self,
12-
host_address: str,
13-
secure_channel: bool,
14-
metadata: Optional[list[tuple[str, str]]] = None,
15-
use_managed_identity: Optional[bool] = False,
16-
client_id: Optional[str] = None,
17-
taskhub: str = None,
18-
**kwargs):
19-
20-
# Ensure metadata is a list
21-
metadata = metadata or []
22-
self._metadata = metadata.copy() # Use a copy to avoid modifying original
23-
24-
# Append DurableTask-specific metadata
25-
self._metadata.append(("taskhub", taskhub or "default-taskhub"))
26-
self._metadata.append(("dts", "True"))
27-
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
28-
self._metadata.append(("client_id", str(client_id or "None")))
29-
30-
self._access_token_manager = AccessTokenManager(metadata=self._metadata)
31-
self.__update_metadata_with_token()
32-
self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
33-
34-
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
35-
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
36-
super().__init__(
37-
host_address=host_address,
38-
secure_channel=secure_channel,
39-
metadata=None,
40-
interceptors=self._interceptors,
41-
**kwargs
42-
)
43-
44-
def __update_metadata_with_token(self):
45-
"""
46-
Add or update the `authorization` key in the metadata with the current access token.
47-
"""
48-
token = self._access_token_manager.get_access_token()
49-
50-
# Ensure that self._metadata is initialized
51-
if self._metadata is None:
52-
self._metadata = [] # Initialize it if it's still None
53-
54-
# Check if "authorization" already exists in the metadata
55-
updated = False
56-
for i, (key, _) in enumerate(self._metadata):
57-
if key == "authorization":
58-
self._metadata[i] = ("authorization", token)
59-
updated = True
60-
break
61-
62-
# If not updated, add a new entry
63-
if not updated:
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from typing import Optional
5+
from durabletask.client import TaskHubGrpcClient
6+
from durabletask.azuremanaged.access_token_manager import AccessTokenManager
7+
from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
8+
9+
# Client class used for Durable Task Scheduler (DTS)
10+
class DurableTaskSchedulerClient(TaskHubGrpcClient):
11+
def __init__(self,
12+
host_address: str,
13+
secure_channel: bool,
14+
metadata: Optional[list[tuple[str, str]]] = None,
15+
use_managed_identity: Optional[bool] = False,
16+
client_id: Optional[str] = None,
17+
taskhub: str = None,
18+
**kwargs):
19+
20+
# Ensure metadata is a list
21+
metadata = metadata or []
22+
self._metadata = metadata.copy() # Use a copy to avoid modifying original
23+
24+
# Append DurableTask-specific metadata
25+
self._metadata.append(("taskhub", taskhub or "default-taskhub"))
26+
self._metadata.append(("dts", "True"))
27+
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
28+
self._metadata.append(("client_id", str(client_id or "None")))
29+
30+
self._access_token_manager = AccessTokenManager(metadata=self._metadata)
31+
self.__update_metadata_with_token()
32+
self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
33+
34+
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
35+
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
36+
super().__init__(
37+
host_address=host_address,
38+
secure_channel=secure_channel,
39+
metadata=None,
40+
interceptors=self._interceptors,
41+
**kwargs
42+
)
43+
44+
def __update_metadata_with_token(self):
45+
"""
46+
Add or update the `authorization` key in the metadata with the current access token.
47+
"""
48+
token = self._access_token_manager.get_access_token()
49+
50+
# Ensure that self._metadata is initialized
51+
if self._metadata is None:
52+
self._metadata = [] # Initialize it if it's still None
53+
54+
# Check if "authorization" already exists in the metadata
55+
updated = False
56+
for i, (key, _) in enumerate(self._metadata):
57+
if key == "authorization":
58+
self._metadata[i] = ("authorization", token)
59+
updated = True
60+
break
61+
62+
# If not updated, add a new entry
63+
if not updated:
6464
self._metadata.append(("authorization", token))

0 commit comments

Comments
 (0)