Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions src/aiida/tools/_dumping/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast

import click

from aiida import orm
from aiida.common import AIIDA_LOGGER
from aiida.common.progress_reporter import get_progress_reporter, set_progress_bar_tqdm
from aiida.tools._dumping.config import GroupDumpConfig, GroupDumpScope, ProfileDumpConfig
from aiida.tools._dumping.mapping import GroupNodeMapping
from aiida.tools._dumping.utils import (
DUMP_PROGRESS_BAR_FORMAT,
REGISTRY_TO_ORM_TYPE,
DumpPaths,
DumpTimes,
Expand Down Expand Up @@ -92,6 +96,7 @@ def get_nodes(
:param exclude_tracked: Whether to exclude nodes already in dump tracker
:return: ProcessingQueue containing the filtered nodes
"""

if group_scope == GroupDumpScope.IN_GROUP and not group:
msg = 'Scope is IN_GROUP but no group object was provided.'
raise ValueError(msg)
Expand All @@ -105,19 +110,25 @@ def get_nodes(
processing_queue = ProcessingQueue()

# Process calculations
logger.report('Querying calculation nodes from database...')
calc_nodes = self._query_single_type(
orm_type=orm.CalculationNode, group_scope=group_scope, group=group, base_filters=base_filters
)
logger.report(f'Retrieved {len(calc_nodes)} calculation nodes.')

if exclude_tracked:
calc_nodes = self._exclude_tracked_nodes(calc_nodes, 'calculations')
if apply_filters:
calc_nodes = self._apply_behavioral_filters(calc_nodes, 'calculations')
processing_queue.calculations = calc_nodes

# Process workflows
logger.report('Querying workflow nodes from database...')
workflow_nodes = self._query_single_type(
orm_type=orm.WorkflowNode, group_scope=group_scope, group=group, base_filters=base_filters
)
logger.report(f'Retrieved {len(workflow_nodes)} workflow nodes.')

if exclude_tracked:
workflow_nodes = self._exclude_tracked_nodes(workflow_nodes, 'workflows')
if apply_filters:
Expand Down Expand Up @@ -190,7 +201,20 @@ def _exclude_tracked_nodes(self, nodes: list[orm.ProcessNode], store_type: str)
if not tracked_uuids:
return nodes

return [node for node in nodes if node.uuid not in tracked_uuids]
return_nodes = []
set_progress_bar_tqdm(bar_format=DUMP_PROGRESS_BAR_FORMAT)

progress_desc = f"{click.style('Report', fg='blue', bold=True)}: Excluding already dumped {store_type}..."
with get_progress_reporter()(desc=progress_desc, total=len(nodes)) as progress:
for node in nodes:
if node.uuid not in tracked_uuids:
return_nodes.append(node)

progress.update()

logger.report(f'Applied exclusion of previously dumped {store_type}.')

return return_nodes

except ValueError as e:
logger.error(f"Error getting registry for '{store_type}': {e}")
Expand All @@ -203,6 +227,7 @@ def _apply_behavioral_filters(self, nodes: list[orm.ProcessNode], store_type: st
:param store_type: Target store (calculations or workflows)
:return: Filtered list of nodes, with top-level and group membership filters applied
"""

if not nodes:
return nodes

Expand All @@ -217,12 +242,20 @@ def _apply_behavioral_filters(self, nodes: list[orm.ProcessNode], store_type: st

# Apply caller filter (keep top-level or explicitly grouped)
filtered_nodes = []
for node in nodes:
is_sub_node = bool(getattr(node, 'caller', None))
is_explicitly_grouped = node.uuid in self.grouped_node_uuids
set_progress_bar_tqdm(bar_format=DUMP_PROGRESS_BAR_FORMAT)

progress_desc = f"{click.style('Report', fg='blue', bold=True)}: Applying filters to {store_type}..."
with get_progress_reporter()(desc=progress_desc, total=len(nodes)) as progress:
for node in nodes:
is_sub_node = bool(getattr(node, 'caller', None))
is_explicitly_grouped = node.uuid in self.grouped_node_uuids

if not is_sub_node or is_explicitly_grouped:
filtered_nodes.append(node)

progress.update()

if not is_sub_node or is_explicitly_grouped:
filtered_nodes.append(node)
logger.report(f'Applied relevant filters to {store_type}.')

return filtered_nodes

Expand Down
26 changes: 21 additions & 5 deletions src/aiida/tools/_dumping/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,20 @@ def __init__(
def _build_mapping_for_target(self) -> GroupNodeMapping:
"""Build the appropriate group-node mapping based on the target entity and config."""
if isinstance(self.dump_target_entity, orm.Group):
# Single group dump - pass it as a single-element list
logger.report(f'Building group-node mapping for single group `{self.dump_target_entity.label}`...')
return GroupNodeMapping.build_from_db(groups=[self.dump_target_entity])

elif isinstance(self.dump_target_entity, Profile):
# Profile dump - depends on config
assert isinstance(self.config, ProfileDumpConfig)

if self.config.all_entries:
# Build mapping for all groups
logger.report('Building group-node mapping for all groups in profile...')
return GroupNodeMapping.build_from_db(groups=None)
elif self.config.groups:
# Build mapping only for specified groups
logger.report(f'Building group-node mapping for {len(self.config.groups)} specified groups...')
return GroupNodeMapping.build_from_db(groups=self.config.groups)
else:
# No groups specified - return empty mapping
return GroupNodeMapping()

else:
Expand All @@ -109,7 +108,7 @@ def _log_dump_start(self) -> None:
elif isinstance(self.dump_target_entity, Profile):
dump_start_report = f'profile `{self.dump_target_entity.name}`'

msg = f'Starting dump of {dump_start_report} in {self.config.dump_mode.name.lower()} mode.'
msg = f'Starting dump of {dump_start_report} in {self.config.dump_mode.name.lower()} mode...'
if self.config.dump_mode != DumpMode.DRY_RUN:
logger.report(msg)

Expand Down Expand Up @@ -226,10 +225,27 @@ def _dump_profile(self) -> None:
return None

self.dump_tracker.set_current_mapping(self.current_mapping)

logger.report('Detecting changes since last dump. This may take a while for large databases...')

logger.report('Detecting node changes...')
node_changes = self.detector._detect_node_changes()
msg = (
f'Detected {len(node_changes.new_or_modified)} new/modified nodes '
f'and {len(node_changes.deleted)} deleted nodes.'
)
logger.report(msg)

logger.report('Detecting group changes...')
group_changes = self.detector._detect_group_changes(
previous_mapping=self.dump_tracker.previous_mapping, current_mapping=self.current_mapping
)
msg = (
f'Detected {len(group_changes.new)} new, {len(group_changes.modified)} modified, '
f'and {len(group_changes.deleted)} deleted groups.'
)
logger.report(msg)

all_changes = DumpChanges(nodes=node_changes, groups=group_changes)

if all_changes.is_empty():
Expand Down
12 changes: 7 additions & 5 deletions src/aiida/tools/_dumping/executors/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union

import click

from aiida import orm
from aiida.common import AIIDA_LOGGER, NotExistent
from aiida.common.progress_reporter import get_progress_reporter, set_progress_bar_tqdm
from aiida.tools._dumping.detect import DumpChangeDetector
from aiida.tools._dumping.tracking import DumpRecord, DumpTracker
from aiida.tools._dumping.utils import DumpChanges, DumpPaths, ProcessingQueue
from aiida.tools._dumping.utils import DUMP_PROGRESS_BAR_FORMAT, DumpChanges, DumpPaths, ProcessingQueue

if TYPE_CHECKING:
from aiida.tools._dumping.config import GroupDumpConfig, ProfileDumpConfig
Expand Down Expand Up @@ -129,7 +131,7 @@ def _dump_nodes(
:param group_context: _description_, defaults to None
:param current_dump_root_for_nodes: _description_, defaults to None
"""
set_progress_bar_tqdm()
set_progress_bar_tqdm(bar_format=DUMP_PROGRESS_BAR_FORMAT)
nodes_to_dump = []
nodes_to_dump.extend(processing_queue.calculations)
nodes_to_dump.extend(processing_queue.workflows)
Expand All @@ -139,7 +141,6 @@ def _dump_nodes(
desc = f'Dumping {len(nodes_to_dump)} nodes'
if group_context:
desc += f" for group '{group_context.label}'"
logger.report(desc)

if current_dump_root_for_nodes is None:
# This is a fallback, the caller should ideally always provide the explicit root.
Expand All @@ -149,7 +150,8 @@ def _dump_nodes(
current_dump_root_for_nodes = self.dump_paths.get_path_for_ungrouped_nodes()
logger.warning(f'current_dump_root_for_nodes was None, derived as: {current_dump_root_for_nodes}')

with get_progress_reporter()(desc=desc, total=len(nodes_to_dump)) as progress:
progress_desc = f"{click.style('Report', fg='blue', bold=True)}: {desc}"
with get_progress_reporter()(desc=progress_desc, total=len(nodes_to_dump)) as progress:
for node in nodes_to_dump:
# Determine the specific, absolute path for this node's dump directory
node_specific_dump_path = self.dump_paths.get_path_for_node(
Expand Down Expand Up @@ -211,7 +213,7 @@ def _handle_group_changes(self, group_changes: GroupChanges) -> None:

:param group_changes: Populated ``GroupChanges`` object
"""
logger.report('Processing group changes.')
logger.report('Processing group changes...')

# Handle Deleted Groups. Actual directory deletion handled by DeletionExecutor, only logging done here.
if group_changes.deleted:
Expand Down
13 changes: 12 additions & 1 deletion src/aiida/tools/_dumping/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from typing import Dict, List, Optional, Set, Union, cast

from aiida import orm
from aiida.common.log import AIIDA_LOGGER
from aiida.tools._dumping.utils import GroupChanges, GroupInfo, GroupModificationInfo, NodeMembershipChange

LOGGER = AIIDA_LOGGER.getChild('tools._dumping.mapping')


@dataclass
class GroupNodeMapping:
Expand Down Expand Up @@ -76,6 +79,7 @@ def build_from_db(cls, groups: Optional[Union[List[orm.Group], List[str], List[i
If None, build mapping for all groups.
:return: Populated ``GroupNodeMapping`` instance
"""

mapping = cls()

# Query all groups and their nodes, or just the specific groups
Expand All @@ -89,15 +93,22 @@ def build_from_db(cls, groups: Optional[Union[List[orm.Group], List[str], List[i
else:
group_uuids = [orm.load_group(g).uuid for g in groups]
qb.append(orm.Group, tag='group', project=['uuid'], filters={'uuid': {'in': group_uuids}})
LOGGER.report(f'Querying node memberships for {len(group_uuids)} group(s)...')
else:
# Query all groups
qb.append(orm.Group, tag='group', project=['uuid'])
LOGGER.report('Querying node memberships for all groups in profile...')

qb.append(orm.Node, with_group='group', project=['uuid'])

for group_uuid, node_uuid in qb.all():
LOGGER.report('Retrieving group-node relationships from database...')
results = qb.all()
LOGGER.report(f'Processing {len(results)} group-node relationships...')

for group_uuid, node_uuid in results:
mapping._add_node_to_group(group_uuid, node_uuid)

LOGGER.report('Completed group-node mapping.')
return mapping

def diff(self, other: 'GroupNodeMapping') -> GroupChanges:
Expand Down
8 changes: 6 additions & 2 deletions src/aiida/tools/_dumping/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
from __future__ import annotations

import os
import sys
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Literal, Optional, Set, Type, Union

try:
if sys.version_info >= (3, 11):
# typing.assert_never available since 3.11
from typing import assert_never
except ImportError:
else:
from typing_extensions import assert_never

from aiida import orm
Expand All @@ -29,6 +30,8 @@

RegistryNameType = Literal['calculations', 'workflows', 'groups']

# Progress bar format for dump operations - wider description field to avoid truncation
DUMP_PROGRESS_BAR_FORMAT = '{desc:60.60}{percentage:6.1f}%|{bar}| {n_fmt}/{total_fmt}'

REGISTRY_TO_ORM_TYPE: dict[str, Type[Union[orm.CalculationNode, orm.WorkflowNode, orm.Group]]] = {
'calculations': orm.CalculationNode,
Expand All @@ -47,6 +50,7 @@
}

__all__ = (
'DUMP_PROGRESS_BAR_FORMAT',
'ORM_TYPE_TO_REGISTRY',
'REGISTRY_TO_ORM_TYPE',
'DumpMode',
Expand Down