Skip to content

Commit fc44fd8

Browse files
authored
Merge pull request #978 from RS-PYTHON/fix/staging-on-cluster
Fix/staging on cluster
2 parents f321573 + 4733573 commit fc44fd8

File tree

5 files changed

+92
-34
lines changed

services/common/rs_server_common/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,13 @@
1414

1515
"""Main package of commons of rs-server services."""
1616

17+
import os
18+
1719
# Set automatically by running `poetry dynamic-versioning`
1820
__version__ = "0.0.0"
21+
22+
# Some kind of workaround for boto3 to avoid checksum being added inside
23+
# the file contents uploaded to the s3 bucket e.g. x-amz-checksum-crc32:xxx
24+
# See: https://github.com/boto/boto3/issues/4435
25+
os.environ["AWS_REQUEST_CHECKSUM_CALCULATION"] = "when_required"
26+
os.environ["AWS_RESPONSE_CHECKSUM_VALIDATION"] = "when_required"

services/staging/rs_server_staging/main.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
init_rs_server_config_yaml,
3232
)
3333
from rs_server_common.db import Base
34-
from rs_server_common.settings import env_bool
34+
from rs_server_common.settings import LOCAL_MODE
3535
from rs_server_common.utils import opentelemetry
3636
from rs_server_common.utils.logging import Logging
3737
from rs_server_common.utils.utils2 import filelock
@@ -207,8 +207,9 @@ async def app_lifespan(fastapi_app: FastAPI): # pylint: disable=too-many-statem
207207
# Create jobs table
208208
process_manager = init_db()
209209

210+
# In local mode, if the gateway is not defined, create a dask LocalCluster
210211
cluster = None
211-
if env_bool("RSPY_LOCAL_MODE", default=False):
212+
if LOCAL_MODE and ("RSPY_DASK_STAGING_CLUSTER_NAME" not in os.environ):
212213
# Create the LocalCluster only in local mode
213214
cluster = LocalCluster()
214215
logger.info("Local Dask cluster created at startup.")
@@ -222,7 +223,7 @@ async def app_lifespan(fastapi_app: FastAPI): # pylint: disable=too-many-statem
222223

223224
# Shutdown logic (cleanup)
224225
logger.info("Shutting down the application...")
225-
if env_bool("RSPY_LOCAL_MODE", default=False) and cluster:
226+
if LOCAL_MODE and cluster:
226227
cluster.close()
227228
logger.info("Local Dask cluster shut down.")
228229

@@ -325,6 +326,15 @@ async def get_specific_job_result_endpoint(job_id: str = Path(..., title="The ID
325326
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=f"Job with ID {job_id} not found") from error
326327

327328

329+
if LOCAL_MODE:
330+
331+
@router.post("/staging/dask/auth")
332+
async def dask_auth(local_dask_username: str, local_dask_password: str):
333+
"""Set dask cluster authentication, only in local mode."""
334+
os.environ["LOCAL_DASK_USERNAME"] = local_dask_username
335+
os.environ["LOCAL_DASK_PASSWORD"] = local_dask_password
336+
337+
328338
# Configure OpenTelemetry
329339
opentelemetry.init_traces(app, "rs.server.staging")
330340

services/staging/rs_server_staging/processors.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424

2525
import requests
2626
from dask.distributed import CancelledError, Client, LocalCluster, as_completed
27-
from dask_gateway import Gateway, JupyterHubAuth
27+
from dask_gateway import Gateway
28+
from dask_gateway.auth import BasicAuth, JupyterHubAuth
2829
from fastapi import HTTPException
2930
from pygeoapi.process.base import BaseProcessor
3031
from pygeoapi.process.manager.postgresql import PostgreSQLManager
@@ -36,6 +37,7 @@
3637
load_external_auth_config_by_domain,
3738
)
3839
from rs_server_common.s3_storage_handler.s3_storage_handler import S3StorageHandler
40+
from rs_server_common.settings import LOCAL_MODE
3941
from rs_server_common.utils.logging import Logging
4042
from starlette.datastructures import Headers
4143
from starlette.requests import Request
@@ -615,7 +617,7 @@ def manage_dask_tasks_results(self, client: Client, catalog_collection: str):
615617
self.log_job_execution(JobStatus.successful, 100, "Finished")
616618
self.logger.info("Tasks monitoring finished")
617619

618-
def dask_cluster_connect(self) -> Client:
620+
def dask_cluster_connect(self) -> Client: # pylint: disable=too-many-branches,too-many-statements
619621
"""Connects a dask cluster scheduler
620622
Establishes a connection to a Dask cluster, either in a local environment or via a Dask Gateway in
621623
a Kubernetes cluster. This method checks if the cluster is already created (for local mode) or connects
@@ -672,42 +674,62 @@ def dask_cluster_connect(self) -> Client:
672674
# If self.cluster is already initialized, it means the application is running in local mode, and
673675
# the cluster was created when the application started.
674676
if not self.cluster:
675-
# in kubernetes cluster mode, we have to connect to the gateway and get the list of the clusters
677+
# Connect to the gateway and get the list of the clusters
676678
try:
677679
# get the name of the cluster
678680
cluster_name = os.environ["RSPY_DASK_STAGING_CLUSTER_NAME"]
679-
# check the auth type, only jupyterhub type supported for now
680-
auth_type = os.environ["DASK_GATEWAY__AUTH__TYPE"]
681-
# Handle JupyterHub authentication
682-
if auth_type == "jupyterhub":
683-
gateway_auth = JupyterHubAuth(api_token=os.environ["JUPYTERHUB_API_TOKEN"])
681+
682+
# In local mode, authenticate to the dask cluster with username/password
683+
if LOCAL_MODE:
684+
gateway_auth = BasicAuth(
685+
os.environ["LOCAL_DASK_USERNAME"],
686+
os.environ["LOCAL_DASK_PASSWORD"],
687+
)
688+
689+
# Cluster mode
684690
else:
685-
self.logger.error(f"Unsupported authentication type: {auth_type}")
686-
raise RuntimeError(f"Unsupported authentication type: {auth_type}")
691+
# check the auth type, only jupyterhub type supported for now
692+
auth_type = os.environ["DASK_GATEWAY__AUTH__TYPE"]
693+
# Handle JupyterHub authentication
694+
if auth_type == "jupyterhub":
695+
gateway_auth = JupyterHubAuth(api_token=os.environ["JUPYTERHUB_API_TOKEN"])
696+
else:
697+
self.logger.error(f"Unsupported authentication type: {auth_type}")
698+
raise RuntimeError(f"Unsupported authentication type: {auth_type}")
699+
687700
gateway = Gateway(
688701
address=os.environ["DASK_GATEWAY__ADDRESS"],
689702
auth=gateway_auth,
690703
)
691-
clusters = gateway.list_clusters()
692-
self.logger.debug(f"The list of clusters: {clusters}")
693704

694-
# Get the identifier of the cluster whose name is equal to the cluster_name variable
705+
# Sort the clusters by newest first
706+
clusters = sorted(gateway.list_clusters(), key=lambda cluster: cluster.start_time, reverse=True)
707+
self.logger.debug(f"Cluster list for gateway {os.environ['DASK_GATEWAY__ADDRESS']!r}: {clusters}")
708+
709+
# In local mode, get the first cluster from the gateway.
710+
cluster_id = None
711+
if LOCAL_MODE:
712+
if clusters:
713+
cluster_id = clusters[0].name
714+
715+
# In cluster mode, get the identifier of the cluster whose name is equal to the cluster_name variable.
695716
# Protection for the case when this cluster does not exit
696-
cluster_id = next(
697-
(
698-
cluster.name
699-
for cluster in clusters
700-
if isinstance(cluster.options, dict) and cluster.options.get("cluster_name") == cluster_name
701-
),
702-
None,
703-
)
717+
else:
718+
cluster_id = next(
719+
(
720+
cluster.name
721+
for cluster in clusters
722+
if isinstance(cluster.options, dict) and cluster.options.get("cluster_name") == cluster_name
723+
),
724+
None,
725+
)
704726

705727
if not cluster_id:
706-
raise IndexError(f"No dask cluster named '{cluster_name}' was found.")
728+
raise IndexError(f"Dask cluster with 'cluster_name'={cluster_name!r} was not found.")
707729

708730
self.cluster = gateway.connect(cluster_id)
709-
710731
self.logger.info(f"Successfully connected to the {cluster_name} dask cluster")
732+
711733
except KeyError as e:
712734
self.logger.exception(
713735
"Failed to retrieve the required connection details for "
@@ -725,6 +747,22 @@ def dask_cluster_connect(self) -> Client:
725747
# create the client as well
726748
client = Client(self.cluster)
727749

750+
# Forward logging from dask workers to the caller
751+
client.forward_logging()
752+
753+
def set_dask_env(host_env: dict):
754+
"""Pass environment variables to the dask workers."""
755+
for name in ["S3_ACCESSKEY", "S3_SECRETKEY", "S3_ENDPOINT", "S3_REGION"]:
756+
os.environ[name] = host_env[name]
757+
758+
# Some kind of workaround for boto3 to avoid checksum being added inside
759+
# the file contents uploaded to the s3 bucket e.g. x-amz-checksum-crc32:xxx
760+
# See: https://github.com/boto/boto3/issues/4435
761+
os.environ["AWS_REQUEST_CHECKSUM_CALCULATION"] = "when_required"
762+
os.environ["AWS_RESPONSE_CHECKSUM_VALIDATION"] = "when_required"
763+
764+
client.run(set_dask_env, os.environ)
765+
728766
# This is a temporary fix for the dask cluster settings which does not create a scheduler by default
729767
# This code should be removed as soon as this is fixed in the kubernetes cluster
730768
try:

services/staging/tests/test_rspy_processor.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,8 @@ def test_dask_cluster_connect(self, mocker, staging_instance: Staging, cluster_o
715715
"RSPY_DASK_STAGING_CLUSTER_NAME": cluster_options["cluster_name"],
716716
},
717717
)
718+
# Mock the cluster mode
719+
mocker.patch("rs_server_staging.processors.LOCAL_MODE", new=False, autospec=False)
718720
# Mock the logger
719721
mock_logger = mocker.patch.object(staging_instance, "logger")
720722
staging_instance.cluster = None
@@ -749,7 +751,9 @@ def test_dask_cluster_connect(self, mocker, staging_instance: Staging, cluster_o
749751
mock_client.assert_called_once_with(staging_instance.cluster)
750752

751753
# Ensure logging was called as expected
752-
mock_logger.debug.assert_any_call(f"The list of clusters: {mock_list_clusters.return_value}")
754+
mock_logger.debug.assert_any_call(
755+
f"Cluster list for gateway 'gateway-address': {mock_list_clusters.return_value}",
756+
)
753757
mock_logger.info.assert_any_call("Number of running workers: 2")
754758
mock_logger.debug.assert_any_call(
755759
f"Dask Client: {client} | Cluster dashboard: {mock_connect.return_value.dashboard_link}",
@@ -768,6 +772,8 @@ def test_dask_cluster_connect_failure_no_cluster_name(self, mocker, staging_inst
768772
"RSPY_DASK_STAGING_CLUSTER_NAME": non_existent_cluster,
769773
},
770774
)
775+
# Mock the cluster mode
776+
mocker.patch("rs_server_staging.processors.LOCAL_MODE", new=False, autospec=False)
771777
# Mock the logger
772778
mock_logger = mocker.patch.object(staging_instance, "logger")
773779
staging_instance.cluster = None
@@ -791,7 +797,8 @@ def test_dask_cluster_connect_failure_no_cluster_name(self, mocker, staging_inst
791797
staging_instance.dask_cluster_connect()
792798
# Ensure logging was called as expected
793799
mock_logger.exception.assert_any_call(
794-
"Failed to find the specified dask cluster: " f"No dask cluster named '{non_existent_cluster}' was found.",
800+
"Failed to find the specified dask cluster: "
801+
f"Dask cluster with 'cluster_name'={non_existent_cluster!r} was not found.",
795802
)
796803

797804
def test_dask_cluster_connect_failure_no_envs(self, mocker, staging_instance: Staging):

services/staging/tests/test_staging.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -526,12 +526,7 @@ async def test_app_lifespan_gateway_error(
526526
"""Test app_lifespan when there is an error in connecting to the Dask Gateway."""
527527

528528
# Mock environment variables to simulate gateway mode
529-
mocker.patch.dict(
530-
os.environ,
531-
{
532-
"RSPY_LOCAL_MODE": "0",
533-
},
534-
)
529+
mocker.patch("rs_server_staging.main.LOCAL_MODE", new=False, autospec=False)
535530

536531
# Mock FastAPI app
537532
mock_app = FastAPI()

0 commit comments

Comments (0)