Skip to content
Merged
11 changes: 11 additions & 0 deletions programs/match_template/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Match template program example scripts and configs

This directory contains example files for configuring the match template program, and example Python scripts for running the match template program.
See the online documentation for comprehensive information on configuring the match template program.

## Files

- `match_template_example_config.yaml` - An example configuration YAML file for constructing a `MatchTemplateManager` object.
- `run_match_template.py` - The default Python script for running the match template program. This supports multi-GPU systems (configure GPUs using the YAML file).
- `run_distributed_match_template.py` - A Python script for running match template on large-scale distributed systems (multi-node clusters). _Use the default script unless you're running on more than one machine_.
- `distributed_match_template.slurm` - An example SLURM script for running the distributed match template (_launching from a workload manager is required_).
78 changes: 78 additions & 0 deletions programs/match_template/distributed_match_template.slurm
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/bin/bash

# ***
# *** This is an example SLURM job script for launching a distributed
# *** match_template job using torchrun over multiple nodes in a cluster.
# *** There are many points at which you will need to modify the script
# *** to fit onto your specific cluster environment.
# ***
# *** NOTE: If you are just trying to saturate GPU resources and have
# ***       enough micrographs to process (and no time pressure for
# ***       results), then it's advisable to just launch multiple
# ***       single-node jobs instead of distributed jobs.
# ***

#SBATCH --job-name=distributed-match-template-%j
#SBATCH --nodes=4               # EDIT: how many nodes allocated
#SBATCH --ntasks-per-node=1     # crucial! - only 1 task per node
#SBATCH --cpus-per-task=8       # EDIT: match number of GPUs per node
#SBATCH --gres=gpu:8            # EDIT: number & type of GPUs per node
#SBATCH --time=2:00:00          # EDIT: desired runtime (hh:mm:ss)
#SBATCH --partition=<part>      # EDIT: your partition
#SBATCH --qos=<qos>             # EDIT: your qos
#SBATCH --account=<acct>        # EDIT: your account name
#SBATCH --output=%x-%j.out


echo "START TIME: $(date)"


# EDIT: Necessary commands to set up your environment *before*
# running the program (e.g. loading modules, conda envs, etc.)
SETUP="ml anaconda3 && \
source ~/.bashrc && \
conda activate leopard-em-dev && \
"

# EDIT: How many GPUs per node.
# NOTE: This must match the per-node GPU count requested in --gres above
# (gpu:8 --> GPUS_PER_NODE=8), otherwise torchrun will launch the wrong
# number of worker processes per node.
GPUS_PER_NODE=8

# EDIT: Define your program and its arguments
PROGRAM="/global/home/users/matthewgiammar/Leopard-EM/programs/match_template/run_distributed_match_template.py"
# OR if CLI arguments are required:
# PROGRAM="programs/match_template/run_distributed_match_template.py --arg1 val1 --arg2 val2"



# Verbose output for debugging purposes (can comment out if not needed)
set -x
srun hostname # each allocated node prints the hostname

# Extract the allocated node names from SLURM and pick the first node as the
# rendezvous head node for torchrun.
allocated_nodes=$(scontrol show hostname "$SLURM_JOB_NODELIST")
nodes=${allocated_nodes//$'\n'/ } # replace newlines with spaces
nodes_array=($nodes)
head_node=${nodes_array[0]}
echo Head Node: $head_node
echo Node List: $nodes
export LOGLEVEL=INFO

# The command for torchrun to launch the distributed job
# NOTE: --rdzv_endpoint requires an open port on the head node (29500 here).
#       There may be restrictions on allowed ports on your cluster...
#       --rdzv_id is just a unique job identifier, so a random number is fine.
LAUNCHER="torchrun \
    --nproc_per_node=$GPUS_PER_NODE \
    --nnodes=$SLURM_JOB_NUM_NODES \
    --rdzv_id=$RANDOM \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$head_node:29500 \
    "
CMD="$SETUP $LAUNCHER $PROGRAM"


echo "Running command:"
echo $CMD
echo "-------------------"
srun /bin/bash -c "$CMD"

echo "END TIME: $(date)"
4 changes: 2 additions & 2 deletions programs/match_template/match_template_example_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ preprocessing_filters:
bandpass_filter:
enabled: false
falloff: null
high_freq_cutoff: null # Both high/low in terms of Nyquist frequency
low_freq_cutoff: null # e.g. low-pass to 3 Å @ 1.06 Å/px would be 0.353
high_freq_cutoff: null # Both in terms of Nyquist frequency (e.g. low-pass to 3 Å @ 1.06 Å/px would be high_freq_cutoff=0.353)
low_freq_cutoff: null
computational_config:
gpu_ids:
- 0
Expand Down
77 changes: 77 additions & 0 deletions programs/match_template/run_distributed_match_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Run the match_template program in a distributed, multi-node environment.

NOTE: This script needs to be launched using `torchrun` and within a distributed
environment where multiple nodes can communicate with each other. See the online
documentation and example scripts for more information on running distributed multi
node match_template.

NOTE: The 'gpu_ids' field in the YAML config is ignored when running in distributed
mode. Each process is assigned to a single GPU based on its local rank.
"""

import os
import time

import torch.distributed as dist

from leopard_em.pydantic_models.managers import MatchTemplateManager

#######################################
### Editable parameters for program ###
#######################################

# NOTE: You can also use `click` to pass arguments to this script from the
# command line instead of hard-coding them here.
# EDIT: Path to the YAML config for the MatchTemplateManager. This example
# points at a developer-specific file -- replace with your own config path.
YAML_CONFIG_PATH = "/global/home/users/matthewgiammar/Leopard-EM/benchmark/tmp/test_match_template_xenon_216_000_0.0_DWS_config.yaml"
# NOTE(review): defined but never referenced in this script -- verify whether
# the results output path should instead come from the YAML config.
DATAFRAME_OUTPUT_PATH = "out.csv"
# Number of orientations processed per batch during the template search.
ORIENTATION_BATCH_SIZE = 20


def initialize_distributed() -> tuple[int, int, int]:
    """Initialize the distributed environment.

    Initializes the default process group with the NCCL backend (one process
    per GPU) and reads the local rank from the ``LOCAL_RANK`` environment
    variable, which `torchrun` sets for each spawned process.

    Returns
    -------
    tuple[int, int, int]
        (world_size, global_rank, local_rank)

    Raises
    ------
    RuntimeError
        If the ``LOCAL_RANK`` environment variable is not set.
    """
    dist.init_process_group(backend="nccl")
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_rank = os.environ.get("LOCAL_RANK", None)

    # Raise error if LOCAL_RANK is not set. This *should* be handled by
    # torchrun, but it is up to the user to rectify this issue on their system.
    if local_rank is None:
        raise RuntimeError(
            "LOCAL_RANK environment variable unset! This script must be "
            "launched via `torchrun` (or with LOCAL_RANK set manually)."
        )

    return world_size, rank, int(local_rank)


def main() -> None:
    """Main function for the distributed match_template program.

    Each process is associated with a single GPU, and we front-load the
    distributed initialization and GPU assignment in this script. This allows
    both the manager object and the backend match_template code to remain
    relatively simple.
    """
    world_size, rank, local_rank = initialize_distributed()
    print(f"RANK={rank}: Initialized {world_size} processes (local_rank={local_rank}).")

    is_rank_zero = bool(rank == 0)

    # Only the zeroth rank pre-loads the mrc files; the data is broadcast to
    # the other ranks later on.
    manager = MatchTemplateManager.from_yaml(
        YAML_CONFIG_PATH, preload_mrc_files=is_rank_zero
    )
    manager.run_match_template_distributed(
        world_size=world_size,
        rank=rank,
        local_rank=local_rank,
        orientation_batch_size=ORIENTATION_BATCH_SIZE,
        do_result_export=is_rank_zero,  # Only save results from rank 0
    )


if __name__ == "__main__":
    # Time the full program run, including distributed setup and teardown.
    _t0 = time.time()
    main()
    _elapsed = time.time() - _t0
    print(f"Total time: {_elapsed:.1f} seconds.")
64 changes: 41 additions & 23 deletions src/leopard_em/backend/core_match_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from functools import partial
from multiprocessing import set_start_method
from typing import Union
from typing import Any, Union

import roma
import torch
Expand All @@ -16,7 +16,10 @@
from leopard_em.backend.cross_correlation import (
do_streamed_orientation_cross_correlate,
)
from leopard_em.backend.distributed import SharedWorkIndexQueue, run_multiprocess_jobs
from leopard_em.backend.distributed import (
SharedWorkIndexQueue,
run_multiprocess_jobs,
)
from leopard_em.backend.process_results import (
aggregate_distributed_results,
decode_global_search_index,
Expand Down Expand Up @@ -173,8 +176,8 @@ def core_match_template(
Gets multiplied with the ctf filters to create a filter stack applied to each
orientation projection.
euler_angles : torch.Tensor
Euler angles (in 'ZYZ' convention) to search over. Has shape
(num_orientations, 3).
Euler angles (in 'ZYZ' convention & in units of degrees) to search over. Has
shape (num_orientations, 3).
defocus_values : torch.Tensor
What defoucs values correspond with the CTF filters, in units of Angstroms. Has
shape (num_defocus,).
Expand Down Expand Up @@ -270,8 +273,6 @@ def core_match_template(
prefetch_size=10,
num_processes=len(device),
)

# Progress tracking for global search and correlations per-device
global_pbar, device_pbars = setup_progress_tracking(
index_queue=index_queue,
unit_scale=defocus_values.shape[0] * pixel_values.shape[0],
Expand Down Expand Up @@ -302,7 +303,7 @@ def core_match_template(
kwargs_per_device.append(kwargs)

result_dict = run_multiprocess_jobs(
target=_core_match_template_single_gpu,
target=_core_match_template_multiprocess_wrapper,
kwargs_list=kwargs_per_device,
post_start_callback=progress_callback,
)
Expand Down Expand Up @@ -348,8 +349,7 @@ def core_match_template(
# pylint: disable=too-many-arguments
# pylint: disable=too-many-positional-arguments
def _core_match_template_single_gpu(
result_dict: dict,
device_id: int,
rank: int,
index_queue: SharedWorkIndexQueue,
image_dft: torch.Tensor,
template_dft: torch.Tensor,
Expand All @@ -360,18 +360,13 @@ def _core_match_template_single_gpu(
orientation_batch_size: int,
num_cuda_streams: int,
device: torch.device,
) -> None:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Single-GPU call for template matching.

NOTE: The result_dict is a shared dictionary between processes and updated in-place
with this processes's results under the 'device_id' key.

Parameters
----------
result_dict : dict
Dictionary to store the results in.
device_id : int
ID of the device which computation is running on. Results will be stored
rank : int
Rank of the device which computation is running on. Results will be stored
in the dictionary with this key.
index_queue : SharedWorkIndexQueue
Torch multiprocessing object for retrieving the next batch of orientations to
Expand Down Expand Up @@ -404,7 +399,14 @@ def _core_match_template_single_gpu(

Returns
-------
None
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Tuple containing the following tensors:
- mip: Maximum intensity projection of the cross-correlation values across
orientation and defocus search space.
- best_global_index: Global index of the best match for each pixel.
- correlation_sum: Sum of cross-correlation values for each pixel.
- correlation_squared_sum: Sum of squared cross-correlation values for
each pixel.
"""
image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2) # adj. for RFFT

Expand All @@ -419,8 +421,6 @@ def _core_match_template_single_gpu(
template_dft = template_dft.to(device)
euler_angles = euler_angles.to(device)
projective_filters = projective_filters.to(device)
# defocus_values = defocus_values.to(device)
# pixel_values = pixel_values.to(device)

num_orientations = euler_angles.shape[0]
num_defocus = defocus_values.shape[0]
Expand Down Expand Up @@ -465,7 +465,7 @@ def _core_match_template_single_gpu(
raise RuntimeError("Exiting due to error in another process.")

try:
indices = index_queue.get_next_indices(process_id=device_id)
indices = index_queue.get_next_indices(process_id=rank)
if indices is None:
break

Expand Down Expand Up @@ -511,7 +511,7 @@ def _core_match_template_single_gpu(
)
except Exception as e:
index_queue.set_error_flag()
print(f"Error occurred in process {device_id}: {e}")
print(f"Error occurred in process {rank}: {e}")
raise e

# Synchronization barrier post-computation
Expand All @@ -520,6 +520,24 @@ def _core_match_template_single_gpu(

torch.cuda.synchronize(device)

return mip, best_global_index, correlation_sum, correlation_squared_sum


def _core_match_template_multiprocess_wrapper(
result_dict: dict, rank: int, **kwargs: dict[str, Any]
) -> None:
"""Wrapper around _core_match_template_single_gpu for use with multiprocessing.

This function places results into a shared dictionary for retrieval by the main
core_match_template function. These results are stored under the 'rank' key, and
they need to exist on the CPU as numpy arrays for the shared dictionary.

See the _core_match_template_single_gpu function for parameter descriptions.
"""
mip, best_global_index, correlation_sum, correlation_squared_sum = (
_core_match_template_single_gpu(rank, **kwargs) # type: ignore[arg-type]
)

# NOTE: Need to send all tensors back to the CPU as numpy arrays for the shared
# process dictionary. This is a workaround for now
result = {
Expand All @@ -531,7 +549,7 @@ def _core_match_template_single_gpu(

# Place the results in the shared multi-process manager dictionary so accessible
# by the main process.
result_dict[device_id] = result
result_dict[rank] = result

# Final cleanup to release all tensors from this GPU
torch.cuda.empty_cache()
Expand Down
Loading