Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions tests/unit/test_globus_refresh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
"""
# Run all tests
pytest test_globus_refresh.py -v

# Run with coverage
pytest test_globus_refresh.py --cov=zstash.globus --cov-report=html

# Run with output capture for debugging
pytest test_globus_refresh.py -v -s
"""

import json
from unittest.mock import Mock, mock_open, patch

import pytest

from zstash.globus import globus_block_wait, globus_transfer
from zstash.globus_utils import load_tokens

# Core functionality tests ####################################################


# Verifies that globus_transfer() calls endpoint_autoactivate for both endpoints
def test_globus_transfer_refreshes_tokens():
"""Test that globus_transfer calls endpoint_autoactivate"""
with patch("zstash.globus.transfer_client") as mock_client, patch(
"zstash.globus.local_endpoint", "local-uuid"
), patch("zstash.globus.remote_endpoint", "remote-uuid"), patch(
"zstash.globus.task_id", None
), patch(
"zstash.globus.transfer_data", None
):

mock_client.endpoint_autoactivate = Mock()
mock_client.operation_ls = Mock(return_value=[])
mock_client.submit_transfer = Mock(return_value={"task_id": "test-123"})

# Call the function
globus_transfer("remote-ep", "/path", "file.tar", "put", False)

# Verify autoactivate was called for both endpoints
assert mock_client.endpoint_autoactivate.call_count >= 2
calls = mock_client.endpoint_autoactivate.call_args_list

# Check it was called with correct parameters
assert any("local-uuid" in str(call) for call in calls)
assert any("remote-uuid" in str(call) for call in calls)
assert any("if_expires_in=86400" in str(call) for call in calls)


# Confirms periodic refresh during long waits
def test_globus_block_wait_refreshes_periodically():
"""Test that globus_block_wait refreshes tokens on each retry"""
with patch("zstash.globus.transfer_client") as mock_client, patch(
"zstash.globus.local_endpoint", "local-uuid"
):

mock_client.endpoint_autoactivate = Mock()
mock_client.task_wait = Mock(return_value=True)
mock_client.get_task = Mock(return_value={"status": "SUCCEEDED"})

# Call with max_retries=3
globus_block_wait("task-123", 1, 1, 3)

# Should call autoactivate at least once per retry
assert mock_client.endpoint_autoactivate.call_count >= 1


# Validates expiration detection logic
def test_load_tokens_detects_expiration(caplog):
"""Test that load_tokens detects soon-to-expire tokens"""
import time as time_module

# Create a token file with expiration in 30 minutes
current_time = 1000000
expires_at = current_time + 1800 # 30 minutes from now

tokens = {
"transfer.api.globus.org": {
"access_token": "fake_token",
"refresh_token": "fake_refresh",
"expires_at": expires_at,
}
}

with patch.object(time_module, "time", return_value=current_time), patch(
"builtins.open", mock_open(read_data=json.dumps(tokens))
), patch("os.path.exists", return_value=True):

with caplog.at_level("INFO"):
result = load_tokens()

# Check that warning was logged
assert "expiring soon" in caplog.text
assert result == tokens


# Library compatibility test ##################################################


def test_token_refresh_with_real_client():
"""
Integration test that uses real Globus SDK but mocks the endpoints.
This verifies the RefreshTokenAuthorizer actually works without needing
real credentials.
"""
from globus_sdk import NativeAppAuthClient, RefreshTokenAuthorizer, TransferClient

from zstash.globus_utils import ZSTASH_CLIENT_ID

# Create a mock authorizer that simulates token refresh
auth_client = NativeAppAuthClient(ZSTASH_CLIENT_ID)

# Create a mock refresh token (won't actually work, but tests the pattern)
mock_refresh_token = "mock_refresh_token_xyz"

try:
# This will fail with invalid token, but we're testing the mechanism exists
authorizer = RefreshTokenAuthorizer(
refresh_token=mock_refresh_token, auth_client=auth_client
)

# Verify the authorizer was created successfully
assert authorizer is not None
assert hasattr(authorizer, "access_token")

# Verify we can create a transfer client with it
transfer_client = TransferClient(authorizer=authorizer)
assert transfer_client is not None

except Exception as e:
# We expect this to fail with auth errors, but not with missing attributes
assert "RefreshTokenAuthorizer" not in str(e)


# Edge case tests #############################################################


# Ensures no issues with many rapid refresh calls
def test_multiple_rapid_refreshes():
"""Test that calling refresh many times doesn't break"""
with patch("zstash.globus.transfer_client") as mock_client:
mock_client.endpoint_autoactivate = Mock()

# Simulate what happens during a long transfer with many wait iterations
for _ in range(100):
mock_client.endpoint_autoactivate("test-endpoint", if_expires_in=86400)

# Should have been called 100 times without error
assert mock_client.endpoint_autoactivate.call_count == 100


# End-to-end test with mocked transfer
def test_small_transfer_with_refresh_enabled():
"""
Functional test: Transfer a small file and verify refresh calls were made.
"""
with patch("zstash.globus.transfer_client") as mock_client, patch(
"zstash.globus.local_endpoint", "local-uuid"
), patch("zstash.globus.remote_endpoint", "remote-uuid"), patch(
"zstash.globus.task_id", None
), patch(
"zstash.globus.transfer_data", None
):

# Set up mock to track calls
mock_client.endpoint_autoactivate = Mock()
mock_client.submit_transfer = Mock(return_value={"task_id": "test-123"})
mock_client.task_wait = Mock(return_value=True)
mock_client.get_task = Mock(return_value={"status": "SUCCEEDED"})

# Run a transfer
globus_transfer("endpoint", "/path", "small.tar", "put", non_blocking=False)

# Verify refresh was called
assert mock_client.endpoint_autoactivate.called


# Parametrized tests ##########################################################


# Tests blocking PUT mode
# Tests non-blocking PUT mode
@pytest.mark.parametrize(
"transfer_type,non_blocking",
[
("put", False),
("put", True),
],
)
def test_globus_transfer_refreshes_in_all_modes(transfer_type, non_blocking):
"""Test that token refresh works for all transfer types"""
with patch("zstash.globus.transfer_client") as mock_client, patch(
"zstash.globus.local_endpoint", "local-uuid"
), patch("zstash.globus.remote_endpoint", "remote-uuid"), patch(
"zstash.globus.task_id", None
), patch(
"zstash.globus.transfer_data", None
), patch(
"zstash.globus.archive_directory_listing", [{"name": "file.tar"}]
):

mock_client.endpoint_autoactivate = Mock()
mock_client.operation_ls = Mock(return_value=[{"name": "file.tar"}])
# Need to return a complete task dict to avoid KeyError
mock_client.submit_transfer = Mock(
return_value={
"task_id": "test-123",
"source_endpoint_id": "src-uuid",
"destination_endpoint_id": "dst-uuid",
"label": "test transfer",
}
)
mock_client.task_wait = Mock(return_value=True)
mock_client.get_task = Mock(
return_value={
"status": "SUCCEEDED",
"source_endpoint_id": "src-uuid",
"destination_endpoint_id": "dst-uuid",
"label": "test transfer",
}
)

globus_transfer("remote-ep", "/path", "file.tar", transfer_type, non_blocking)

# Verify refresh was called
assert mock_client.endpoint_autoactivate.called


# Fixture example #############################################################


@pytest.fixture
def mock_globus_client():
"""Fixture to set up a mock Globus client"""
with patch("zstash.globus.transfer_client") as mock_client, patch(
"zstash.globus.local_endpoint", "local-uuid"
), patch("zstash.globus.remote_endpoint", "remote-uuid"), patch(
"zstash.globus.task_id", None
), patch(
"zstash.globus.transfer_data", None
):

mock_client.endpoint_autoactivate = Mock()
mock_client.operation_ls = Mock(return_value=[])
mock_client.submit_transfer = Mock(
return_value={
"task_id": "test-123",
"source_endpoint_id": "src-uuid",
"destination_endpoint_id": "dst-uuid",
"label": "test transfer",
}
)
mock_client.task_wait = Mock(return_value=True)
mock_client.get_task = Mock(
return_value={
"status": "SUCCEEDED",
"source_endpoint_id": "src-uuid",
"destination_endpoint_id": "dst-uuid",
"label": "test transfer",
}
)

yield mock_client


# Demonstrates reusable fixture pattern
def test_with_fixture(mock_globus_client):
"""Test using the fixture"""
globus_transfer("remote-ep", "/path", "file.tar", "put", False)
assert mock_globus_client.endpoint_autoactivate.call_count >= 2
49 changes: 46 additions & 3 deletions zstash/globus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import, print_function

import sys
import time
from typing import List, Optional

from globus_sdk import TransferAPIError, TransferClient, TransferData
Expand All @@ -26,6 +27,30 @@
task_id = None
archive_directory_listing: IterableTransferResponse = None

DEBUG_LONG_TRANSFER: bool = False # Set to true if testing token expiration handling


def _debug_sleep_to_expire_token(context: str, retry_count: int = 0):
"""
FOR DEBUGGING ONLY: Sleep to simulate token expiration during long operations.

Args:
context: Description of where this is being called (e.g., "blocking", "non-blocking")
retry_count: Current retry count (only sleep on first iteration)
"""
if DEBUG_LONG_TRANSFER and retry_count == 0:
transfer_duration_mock_hours = 49
logger.info(
f"{ts_utc()}: TESTING ({context}): Sleeping for {transfer_duration_mock_hours} hours to let access token expire"
)
time.sleep(transfer_duration_mock_hours * 3600)
logger.info(
f"{ts_utc()}: TESTING ({context}): Woke up after {transfer_duration_mock_hours} hours. "
"Access token expired, RefreshTokenAuthorizer should automatically refresh on next API call."
)
return True # Indicates sleep happened
return False # No sleep


def globus_activate(hpss: str):
"""
Expand Down Expand Up @@ -83,6 +108,17 @@ def globus_transfer( # noqa: C901
if not transfer_client:
sys.exit(1)

# Force token refresh before long operation
try:
# Make a simple API call to trigger refresh if needed
transfer_client.endpoint_autoactivate(local_endpoint, if_expires_in=86400)
transfer_client.endpoint_autoactivate(remote_endpoint, if_expires_in=86400)

# FOR DEBUGGING: Test non-blocking mode token refresh
_debug_sleep_to_expire_token("non-blocking")
except Exception as e:
logger.warning(f"Token refresh check: {e}")

if transfer_type == "get":
if not archive_directory_listing:
archive_directory_listing = transfer_client.operation_ls(
Expand Down Expand Up @@ -195,16 +231,18 @@ def globus_transfer( # noqa: C901
def globus_block_wait(
task_id: str, wait_timeout: int, polling_interval: int, max_retries: int
):

# poll every "polling_interval" seconds to speed up small transfers. Report every 2 hours, stop waiting aftert 5*2 = 10 hours
logger.info(
f"{ts_utc()}: BLOCKING START: invoking task_wait for task_id = {task_id}"
)
task_status = "UNKNOWN"
retry_count = 0

while retry_count < max_retries:
try:
# Wait for the task to complete
# Refresh token before each wait attempt
transfer_client.endpoint_autoactivate(local_endpoint, if_expires_in=86400)
transfer_client.endpoint_autoactivate(remote_endpoint, if_expires_in=86400)

logger.info(
f"{ts_utc()}: on task_wait try {retry_count + 1} out of {max_retries}"
)
Expand All @@ -217,6 +255,11 @@ def globus_block_wait(
else:
curr_task = transfer_client.get_task(task_id)
task_status = curr_task["status"]

# FOR DEBUGGING: Test blocking mode token refresh
if _debug_sleep_to_expire_token("blocking", retry_count):
continue # Force another iteration to test refresh

if task_status == "SUCCEEDED":
break
finally:
Expand Down
16 changes: 15 additions & 1 deletion zstash/globus_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import socket
import sys
import time
from typing import Dict, List, Optional

from globus_sdk import (
Expand Down Expand Up @@ -185,7 +186,20 @@ def load_tokens():
if os.path.exists(TOKEN_FILE):
try:
with open(TOKEN_FILE, "r") as f:
return json.load(f)
tokens = json.load(f)

# Check if access token is expired or expiring soon
transfer_token = tokens.get("transfer.api.globus.org", {})
expires_at = transfer_token.get("expires_at")

if expires_at:
# Refresh if expiring within 1 hour
if time.time() > (expires_at - 3600):
logger.info(
"Access token expired or expiring soon - will need refresh"
)

return tokens
except (json.JSONDecodeError, IOError):
return {}
return {}
Expand Down
Loading