Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion babs/merge.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import math
import os
import os.path as op
import re
import shutil
import stat
import subprocess
import time
import warnings

import datalad.api as dlapi
Expand All @@ -11,6 +14,88 @@
from babs.utils import get_git_show_ref_shasum


def robust_rm_dir(path, max_retries=3, retry_delay=1):
"""
Robustly remove a directory tree, handling filesystem quirks and locked files.

For datalad datasets, this function prioritizes `datalad remove`.
Falls back to shutil.rmtree() with retries if needed.

Parameters
----------
path : str
Path to the directory to remove
max_retries : int
Maximum number of retry attempts for shutil.rmtree()
retry_delay : float
Delay in seconds between retries
"""
if not op.exists(path):
return

# Check if it's a datalad dataset (presence of a `.datalad` directory).
is_datalad_dataset = op.exists(op.join(path, '.datalad'))

# For datalad datasets, try datalad remove first
if is_datalad_dataset:
try:
# Untracked files in merge_ds can block `datalad remove`, so discard untracked.
if op.exists(op.join(path, '.git')):
subprocess.run(
['git', 'clean', '-fdx'],
cwd=path,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
dlapi.remove(path=path, dataset=path, reckless='availability')
# datalad remove might not remove everything, check if path still exists
if not op.exists(path):
return
# If path still exists, fall through to shutil.rmtree() below
except Exception as e:
print(f'Warning: datalad remove failed for {path}: {e}')
print('Falling back to shutil.rmtree()...')

# Fallback: use shutil.rmtree() with error handling and retries
def handle_remove_readonly(func, path, _exc):
"""
Error handler for shutil.rmtree that attempts to fix permission issues.
"""
# Change file to be writable, readable, and executable
os.chmod(path, stat.S_IWRITE | stat.S_IREAD | stat.S_IEXEC)
try:
func(path)
except Exception as e:
# If still fails, try to remove as file
print(f'Warning: Failed to remove {path} after fixing permissions: {e}')
try:
os.remove(path)
except Exception as e2:
print(f'Warning: Failed to remove {path} as file: {e2}')

# Try shutil.rmtree() with retries
for attempt in range(max_retries):
try:
shutil.rmtree(path, onerror=handle_remove_readonly)
return
except OSError as e:
if attempt < max_retries - 1:
print(
f'Warning: Failed to remove {path} (attempt {attempt + 1}/{max_retries}): {e}'
)
time.sleep(retry_delay)
continue
else:
# All retries failed, warn but don't crash since merge was successful
warnings.warn(
f"Failed to remove temporary directory '{path}' after {max_retries} attempts. "
f'Error: {e}. '
'The merge was successful, but you may need to manually remove '
'this directory. You can safely delete it with: rm -rf ' + path,
stacklevel=2,
)


class BABSMerge(BABS):
"""BABSMerge is for merging results and provenance from finished jobs."""

Expand Down Expand Up @@ -279,7 +364,8 @@ def babs_merge(self, chunk_size=1000, trial_run=False):
print('\n`babs merge` was successful!')

# delete the merge_ds folder
shutil.rmtree(merge_ds_path)
print('\nCleaning up merge_ds directory...')
robust_rm_dir(merge_ds_path)

# Delete all the merged branches from the output RIA
for n_chunk, chunk in enumerate(all_chunks):
Expand Down
228 changes: 226 additions & 2 deletions tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Test merge.py error handling and edge cases."""

import os
import shutil
import stat
import subprocess
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import datalad.api as dlapi
import pytest

from babs.merge import BABSMerge
from babs.merge import BABSMerge, robust_rm_dir
from babs.utils import get_git_show_ref_shasum


Expand Down Expand Up @@ -75,3 +79,223 @@ def mock_remote_show(cmd, **kwargs):

with pytest.raises(Exception, match='There is no job branch in output RIA that has results'):
babs_proj.babs_merge()


def test_rm_dir_nonexistent(tmp_path):
"""Test robust_rm_dir when path doesn't exist."""
nonexistent_path = tmp_path / 'nonexistent'
# Should not raise an error
robust_rm_dir(str(nonexistent_path))


def test_rm_dir_regular(tmp_path):
"""Test robust_rm_dir with a regular (non-datalad) directory."""
test_dir = tmp_path / 'test_dir'
test_dir.mkdir()
(test_dir / 'file.txt').write_text('test')

robust_rm_dir(str(test_dir))

assert not test_dir.exists()


def test_rm_dir_datalad_ok(tmp_path, monkeypatch):
"""Test robust_rm_dir with datalad dataset that removes successfully."""
test_dir = tmp_path / 'test_datalad'
test_dir.mkdir()
(test_dir / '.datalad').mkdir()
(test_dir / 'file.txt').write_text('test')

# Mock datalad.remove to succeed
def mock_remove(path=None, dataset=None, **kwargs):
import shutil

shutil.rmtree(dataset or path)

monkeypatch.setattr(dlapi, 'remove', mock_remove)

robust_rm_dir(str(test_dir))

assert not test_dir.exists()


def test_rm_dir_datalad_fail(tmp_path, monkeypatch):
"""Test robust_rm_dir when datalad.remove fails and falls back to shutil.rmtree."""
test_dir = tmp_path / 'test_datalad'
test_dir.mkdir()
(test_dir / '.datalad').mkdir()
(test_dir / 'file.txt').write_text('test')

# Mock datalad.remove to raise an exception
def mock_remove(path=None, dataset=None, **kwargs):
raise Exception('datalad remove failed')

monkeypatch.setattr(dlapi, 'remove', mock_remove)

robust_rm_dir(str(test_dir))

assert not test_dir.exists()


def test_rm_dir_datalad_partial(tmp_path, monkeypatch):
"""Test robust_rm_dir when datalad.remove succeeds but path still exists."""
test_dir = tmp_path / 'test_datalad'
test_dir.mkdir()
(test_dir / '.datalad').mkdir()
(test_dir / 'file.txt').write_text('test')

# Mock datalad.remove to succeed but not remove everything
def mock_remove(path=None, dataset=None, **kwargs):
# Remove some files but leave the directory
(test_dir / 'file.txt').unlink()
# Don't actually remove the directory

monkeypatch.setattr(dlapi, 'remove', mock_remove)

robust_rm_dir(str(test_dir))

# Should fall back to shutil.rmtree and remove everything
assert not test_dir.exists()


def test_rm_dir_git(tmp_path):
"""Test robust_rm_dir with a plain git repo (has .git directory but no .datalad)."""
test_dir = tmp_path / 'test_git'
test_dir.mkdir()
(test_dir / '.git').mkdir()
(test_dir / 'file.txt').write_text('test')

# Should NOT call datalad.remove for a plain git repo
with patch.object(dlapi, 'remove') as mock_remove:
robust_rm_dir(str(test_dir))
mock_remove.assert_not_called()

# Should be removed via shutil.rmtree fallback
assert not test_dir.exists()


def test_rm_dir_permission(tmp_path, monkeypatch):
"""Test robust_rm_dir handling of permission errors."""
test_dir = tmp_path / 'test_readonly'
test_dir.mkdir()
readonly_file = test_dir / 'readonly.txt'
readonly_file.write_text('test')

# Make file readonly
os.chmod(str(readonly_file), stat.S_IREAD)

# Mock shutil.rmtree to simulate a permission error on first attempt,
# then succeed on a retry.
original_rmtree = shutil.rmtree
call_count = {'count': 0}

def mock_rmtree(path, onerror=None):
call_count['count'] += 1
if call_count['count'] == 1:
# First call: trigger onerror, then raise so robust_rm_dir retries.
if onerror:
err = PermissionError(13, 'Permission denied', str(readonly_file))
onerror(os.remove, str(readonly_file), (PermissionError, err, None))
raise OSError('Permission denied')
else:
# Subsequent calls: succeed
original_rmtree(path, onerror=onerror)

monkeypatch.setattr('shutil.rmtree', mock_rmtree)

robust_rm_dir(str(test_dir), max_retries=3, retry_delay=0)

assert not test_dir.exists()
assert call_count['count'] == 2


def test_rm_dir_retry(tmp_path, monkeypatch):
"""Test robust_rm_dir retry logic when removal fails initially."""
test_dir = tmp_path / 'test_retry'
test_dir.mkdir()
(test_dir / 'file.txt').write_text('test')

# Mock shutil.rmtree to fail twice then succeed
original_rmtree = shutil.rmtree
call_count = {'count': 0}

def mock_rmtree(path, onerror=None):
call_count['count'] += 1
if call_count['count'] < 3:
raise OSError(f'Failed attempt {call_count["count"]}')
original_rmtree(path, onerror=onerror)

monkeypatch.setattr('shutil.rmtree', mock_rmtree)

robust_rm_dir(str(test_dir), max_retries=3, retry_delay=0)

assert not test_dir.exists()
assert call_count['count'] == 3


def test_rm_dir_max_retries(tmp_path, monkeypatch):
"""Test robust_rm_dir when max retries are exceeded."""
test_dir = tmp_path / 'test_fail'
test_dir.mkdir()
(test_dir / 'file.txt').write_text('test')

# Mock shutil.rmtree to always fail
def mock_rmtree(path, onerror=None):
raise OSError('Always fails')

monkeypatch.setattr('shutil.rmtree', mock_rmtree)

# Should warn but not crash
with pytest.warns(UserWarning, match='Failed to remove temporary directory'):
robust_rm_dir(str(test_dir), max_retries=2, retry_delay=0)

# Directory should still exist
assert test_dir.exists()


def test_merge_existing(babs_project_sessionlevel, tmp_path, monkeypatch):
"""Test babs_merge when merge_ds already exists."""
babs_proj = BABSMerge(babs_project_sessionlevel)

# Create merge_ds directory
merge_ds_path = tmp_path / 'merge_ds'
merge_ds_path.mkdir()
monkeypatch.setattr(babs_proj, 'project_root', str(tmp_path))

with pytest.raises(Exception, match="Folder 'merge_ds' already exists"):
babs_proj.babs_merge()


def test_merge_no_head(babs_project_sessionlevel, tmp_path, monkeypatch):
"""Test babs_merge when there's no HEAD branch."""
babs_proj = BABSMerge(babs_project_sessionlevel)

monkeypatch.setattr(babs_proj, 'project_root', str(tmp_path))

def set_analysis_id():
babs_proj.analysis_dataset_id = 'test-id'

monkeypatch.setattr(babs_proj, 'wtf_key_info', set_analysis_id)
monkeypatch.setattr(babs_proj, '_get_results_branches', lambda: ['job-123'])

from babs.merge import dlapi

def mock_clone(source, path):
# Create the directory so subsequent git commands can run
os.makedirs(path, exist_ok=True)
return None

monkeypatch.setattr(dlapi, 'clone', mock_clone)

def mock_remote_show(cmd, **kwargs):
if 'remote' in cmd and 'show' in cmd:
result = MagicMock()
result.returncode = 0
result.stdout = b'No HEAD branch found\n' # No HEAD branch
return result
return subprocess.run(cmd, **kwargs)

monkeypatch.setattr('babs.merge.subprocess.run', mock_remote_show)

with pytest.raises(Exception, match='There is no HEAD branch in output RIA!'):
babs_proj.babs_merge()
7 changes: 3 additions & 4 deletions tests/test_update_input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
babs_submit_main,
babs_update_input_data_main,
)
from babs.scheduler import squeue_to_pandas


def manually_add_new_subject_to_input_data(input_data_path: str):
Expand Down Expand Up @@ -152,9 +151,9 @@ def test_babs_update_input_data(
for waitnum in [5, 8, 10, 15, 30, 60, 120]:
time.sleep(waitnum)
print(f'Waiting {waitnum} seconds...')
df = squeue_to_pandas()
print(df)
if df.empty:
currently_running_df = babs_proj.get_currently_running_jobs_df()
print(currently_running_df)
if currently_running_df.empty:
finished = True
break

Expand Down