diff --git a/tests/unit/test_globus_refresh.py b/tests/unit/test_globus_refresh.py new file mode 100644 index 00000000..c4fed0d0 --- /dev/null +++ b/tests/unit/test_globus_refresh.py @@ -0,0 +1,271 @@ +""" +# Run all tests +pytest test_globus_refresh.py -v + +# Run with coverage +pytest test_globus_refresh.py --cov=zstash.globus --cov-report=html + +# Run with output capture for debugging +pytest test_globus_refresh.py -v -s +""" + +import json +from unittest.mock import Mock, mock_open, patch + +import pytest + +from zstash.globus import globus_block_wait, globus_transfer +from zstash.globus_utils import load_tokens + +# Core functionality tests #################################################### + + +# Verifies that globus_transfer() calls endpoint_autoactivate for both endpoints +def test_globus_transfer_refreshes_tokens(): + """Test that globus_transfer calls endpoint_autoactivate""" + with patch("zstash.globus.transfer_client") as mock_client, patch( + "zstash.globus.local_endpoint", "local-uuid" + ), patch("zstash.globus.remote_endpoint", "remote-uuid"), patch( + "zstash.globus.task_id", None + ), patch( + "zstash.globus.transfer_data", None + ): + + mock_client.endpoint_autoactivate = Mock() + mock_client.operation_ls = Mock(return_value=[]) + mock_client.submit_transfer = Mock(return_value={"task_id": "test-123"}) + + # Call the function + globus_transfer("remote-ep", "/path", "file.tar", "put", False) + + # Verify autoactivate was called for both endpoints + assert mock_client.endpoint_autoactivate.call_count >= 2 + calls = mock_client.endpoint_autoactivate.call_args_list + + # Check it was called with correct parameters + assert any("local-uuid" in str(call) for call in calls) + assert any("remote-uuid" in str(call) for call in calls) + assert any("if_expires_in=86400" in str(call) for call in calls) + + +# Confirms periodic refresh during long waits +def test_globus_block_wait_refreshes_periodically(): + """Test that globus_block_wait refreshes tokens on each retry""" + with patch("zstash.globus.transfer_client") as mock_client, patch( + "zstash.globus.local_endpoint", "local-uuid" + ): + + mock_client.endpoint_autoactivate = Mock() + mock_client.task_wait = Mock(return_value=True) + mock_client.get_task = Mock(return_value={"status": "SUCCEEDED"}) + + # Call with max_retries=3 + globus_block_wait("task-123", 1, 1, 3) + + # Should call autoactivate at least once per retry + assert mock_client.endpoint_autoactivate.call_count >= 1 + + +# Validates expiration detection logic +def test_load_tokens_detects_expiration(caplog): + """Test that load_tokens detects soon-to-expire tokens""" + import time as time_module + + # Create a token file with expiration in 30 minutes + current_time = 1000000 + expires_at = current_time + 1800 # 30 minutes from now + + tokens = { + "transfer.api.globus.org": { + "access_token": "fake_token", + "refresh_token": "fake_refresh", + "expires_at": expires_at, + } + } + + with patch.object(time_module, "time", return_value=current_time), patch( + "builtins.open", mock_open(read_data=json.dumps(tokens)) + ), patch("os.path.exists", return_value=True): + + with caplog.at_level("INFO"): + result = load_tokens() + + # Check that warning was logged + assert "expiring soon" in caplog.text + assert result == tokens + + +# Library compatibility test ################################################## + + +def test_token_refresh_with_real_client(): + """ + Integration test that uses real Globus SDK but mocks the endpoints. + This verifies the RefreshTokenAuthorizer actually works without needing + real credentials. + """ + from globus_sdk import NativeAppAuthClient, RefreshTokenAuthorizer, TransferClient + + from zstash.globus_utils import ZSTASH_CLIENT_ID + + # Create a mock authorizer that simulates token refresh + auth_client = NativeAppAuthClient(ZSTASH_CLIENT_ID) + + # Create a mock refresh token (won't actually work, but tests the pattern) + mock_refresh_token = "mock_refresh_token_xyz" + + try: + # This will fail with invalid token, but we're testing the mechanism exists + authorizer = RefreshTokenAuthorizer( + refresh_token=mock_refresh_token, auth_client=auth_client + ) + + # Verify the authorizer was created successfully + assert authorizer is not None + assert hasattr(authorizer, "access_token") + + # Verify we can create a transfer client with it + transfer_client = TransferClient(authorizer=authorizer) + assert transfer_client is not None + + except Exception as e: + # We expect this to fail with auth errors, but not with missing attributes + assert "RefreshTokenAuthorizer" not in str(e) + + +# Edge case tests ############################################################# + + +# Ensures no issues with many rapid refresh calls +def test_multiple_rapid_refreshes(): + """Test that calling refresh many times doesn't break""" + with patch("zstash.globus.transfer_client") as mock_client: + mock_client.endpoint_autoactivate = Mock() + + # Simulate what happens during a long transfer with many wait iterations + for _ in range(100): + mock_client.endpoint_autoactivate("test-endpoint", if_expires_in=86400) + + # Should have been called 100 times without error + assert mock_client.endpoint_autoactivate.call_count == 100 + + +# End-to-end test with mocked transfer +def test_small_transfer_with_refresh_enabled(): + """ + Functional test: Transfer a small file and verify refresh calls were made. + """ + with patch("zstash.globus.transfer_client") as mock_client, patch( + "zstash.globus.local_endpoint", "local-uuid" + ), patch("zstash.globus.remote_endpoint", "remote-uuid"), patch( + "zstash.globus.task_id", None + ), patch( + "zstash.globus.transfer_data", None + ): + + # Set up mock to track calls + mock_client.endpoint_autoactivate = Mock() + mock_client.submit_transfer = Mock(return_value={"task_id": "test-123"}) + mock_client.task_wait = Mock(return_value=True) + mock_client.get_task = Mock(return_value={"status": "SUCCEEDED"}) + + # Run a transfer + globus_transfer("endpoint", "/path", "small.tar", "put", non_blocking=False) + + # Verify refresh was called + assert mock_client.endpoint_autoactivate.called + + +# Parametrized tests ########################################################## + + +# Tests blocking PUT mode +# Tests non-blocking PUT mode +@pytest.mark.parametrize( + "transfer_type,non_blocking", + [ + ("put", False), + ("put", True), + ], +) +def test_globus_transfer_refreshes_in_all_modes(transfer_type, non_blocking): + """Test that token refresh works for all transfer types""" + with patch("zstash.globus.transfer_client") as mock_client, patch( + "zstash.globus.local_endpoint", "local-uuid" + ), patch("zstash.globus.remote_endpoint", "remote-uuid"), patch( + "zstash.globus.task_id", None + ), patch( + "zstash.globus.transfer_data", None + ), patch( + "zstash.globus.archive_directory_listing", [{"name": "file.tar"}] + ): + + mock_client.endpoint_autoactivate = Mock() + mock_client.operation_ls = Mock(return_value=[{"name": "file.tar"}]) + # Need to return a complete task dict to avoid KeyError + mock_client.submit_transfer = Mock( + return_value={ + "task_id": "test-123", + "source_endpoint_id": "src-uuid", + "destination_endpoint_id": "dst-uuid", + "label": "test transfer", + } + ) + mock_client.task_wait = Mock(return_value=True) + mock_client.get_task = Mock( + return_value={ + "status": "SUCCEEDED", + "source_endpoint_id": "src-uuid", + "destination_endpoint_id": "dst-uuid", + "label": "test transfer", + } + ) + + globus_transfer("remote-ep", "/path", "file.tar", transfer_type, non_blocking) + + # Verify refresh was called + assert mock_client.endpoint_autoactivate.called + + +# Fixture example ############################################################# + + +@pytest.fixture +def mock_globus_client(): + """Fixture to set up a mock Globus client""" + with patch("zstash.globus.transfer_client") as mock_client, patch( + "zstash.globus.local_endpoint", "local-uuid" + ), patch("zstash.globus.remote_endpoint", "remote-uuid"), patch( + "zstash.globus.task_id", None + ), patch( + "zstash.globus.transfer_data", None + ): + + mock_client.endpoint_autoactivate = Mock() + mock_client.operation_ls = Mock(return_value=[]) + mock_client.submit_transfer = Mock( + return_value={ + "task_id": "test-123", + "source_endpoint_id": "src-uuid", + "destination_endpoint_id": "dst-uuid", + "label": "test transfer", + } + ) + mock_client.task_wait = Mock(return_value=True) + mock_client.get_task = Mock( + return_value={ + "status": "SUCCEEDED", + "source_endpoint_id": "src-uuid", + "destination_endpoint_id": "dst-uuid", + "label": "test transfer", + } + ) + + yield mock_client + + +# Demonstrates reusable fixture pattern +def test_with_fixture(mock_globus_client): + """Test using the fixture""" + globus_transfer("remote-ep", "/path", "file.tar", "put", False) + assert mock_globus_client.endpoint_autoactivate.call_count >= 2 diff --git a/zstash/globus.py b/zstash/globus.py index 2cacad5f..b16f2b3f 100644 --- a/zstash/globus.py +++ b/zstash/globus.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, print_function import sys +import time from typing import List, Optional from globus_sdk import TransferAPIError, TransferClient, TransferData @@ -26,6 +27,30 @@ task_id = None archive_directory_listing: IterableTransferResponse = None +DEBUG_LONG_TRANSFER: bool = False # Set to true if testing token expiration handling + + +def _debug_sleep_to_expire_token(context: str, retry_count: int = 0): + """ + FOR DEBUGGING ONLY: Sleep to simulate token expiration during long operations. + + Args: + context: Description of where this is being called (e.g., "blocking", "non-blocking") + retry_count: Current retry count (only sleep on first iteration) + """ + if DEBUG_LONG_TRANSFER and retry_count == 0: + transfer_duration_mock_hours = 49 + logger.info( + f"{ts_utc()}: TESTING ({context}): Sleeping for {transfer_duration_mock_hours} hours to let access token expire" + ) + time.sleep(transfer_duration_mock_hours * 3600) + logger.info( + f"{ts_utc()}: TESTING ({context}): Woke up after {transfer_duration_mock_hours} hours. " + "Access token expired, RefreshTokenAuthorizer should automatically refresh on next API call." + ) + return True # Indicates sleep happened + return False # No sleep + def globus_activate(hpss: str): """ @@ -83,6 +108,17 @@ def globus_transfer( # noqa: C901 if not transfer_client: sys.exit(1) + # Force token refresh before long operation + try: + # Make a simple API call to trigger refresh if needed + transfer_client.endpoint_autoactivate(local_endpoint, if_expires_in=86400) + transfer_client.endpoint_autoactivate(remote_endpoint, if_expires_in=86400) + + # FOR DEBUGGING: Test non-blocking mode token refresh + _debug_sleep_to_expire_token("non-blocking") + except Exception as e: + logger.warning(f"Token refresh check: {e}") + if transfer_type == "get": if not archive_directory_listing: archive_directory_listing = transfer_client.operation_ls( @@ -195,16 +231,18 @@ def globus_transfer( # noqa: C901 def globus_block_wait( task_id: str, wait_timeout: int, polling_interval: int, max_retries: int ): - - # poll every "polling_interval" seconds to speed up small transfers. Report every 2 hours, stop waiting aftert 5*2 = 10 hours logger.info( f"{ts_utc()}: BLOCKING START: invoking task_wait for task_id = {task_id}" ) task_status = "UNKNOWN" retry_count = 0 + while retry_count < max_retries: try: - # Wait for the task to complete + # Refresh token before each wait attempt + transfer_client.endpoint_autoactivate(local_endpoint, if_expires_in=86400) + transfer_client.endpoint_autoactivate(remote_endpoint, if_expires_in=86400) + logger.info( f"{ts_utc()}: on task_wait try {retry_count + 1} out of {max_retries}" ) @@ -217,6 +255,11 @@ def globus_block_wait( else: curr_task = transfer_client.get_task(task_id) task_status = curr_task["status"] + + # FOR DEBUGGING: Test blocking mode token refresh + if _debug_sleep_to_expire_token("blocking", retry_count): + continue # Force another iteration to test refresh + if task_status == "SUCCEEDED": break finally: diff --git a/zstash/globus_utils.py b/zstash/globus_utils.py index e5346f69..ee1c0ec7 100644 --- a/zstash/globus_utils.py +++ b/zstash/globus_utils.py @@ -7,6 +7,7 @@ import re import socket import sys +import time from typing import Dict, List, Optional from globus_sdk import ( @@ -185,7 +186,20 @@ def load_tokens(): if os.path.exists(TOKEN_FILE): try: with open(TOKEN_FILE, "r") as f: - return json.load(f) + tokens = json.load(f) + + # Check if access token is expired or expiring soon + transfer_token = tokens.get("transfer.api.globus.org", {}) + expires_at = transfer_token.get("expires_at") + + if expires_at: + # Refresh if expiring within 1 hour + if time.time() > (expires_at - 3600): + logger.info( + "Access token expired or expiring soon - will need refresh" + ) + + return tokens except (json.JSONDecodeError, IOError): return {} return {}