Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,19 @@ jobs:

steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v5
- name: Set up Python 3.12
uses: actions/setup-python@v3
with:
python-version: "3.12"
run: uv python install 3.12
- name: Install dependencies
run: uv sync
- name: Lint with ruff
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install .
pip install -r requirements-pytest.txt
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
uv run ruff check .
uv run ruff format --check .
- name: Type check with ty
continue-on-error: true
run: uv run ty check
- name: Test with pytest
run: |
pytest
uv run pytest
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ venv.bak/
.dmypy.json
dmypy.json

# ruff
.ruff_cache/

# Pyre type checker
.pyre/

Expand Down Expand Up @@ -184,3 +187,8 @@ workflows/work/

# Temp code file
draft.py

# Exclude subject data folders (symlinks)
sub-*

.python-version
19 changes: 19 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.8
hooks:
- id: ruff-format
- id: ruff
args: [--fix]
- repo: local
hooks:
- id: ty
name: ty (type check)
entry: bash -c 'uv run ty check || true'
language: system
types: [python]
files: ^linumpy/
pass_filenames: false
# Non-blocking (advisory) until the existing error baseline is resolved.
# Switch entry to just "uv run ty check" once errors reach zero.
verbose: true
11 changes: 5 additions & 6 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
FROM python:3.12

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /linumpy/

ENV PYTHONDONTWRITEBYTECODE=1
Expand All @@ -14,11 +16,8 @@ RUN apt-get update && apt-get install -y \
zip \
&& rm -rf /var/lib/apt/lists/*

# Upgrade pip, setuptools and wheel
RUN pip install --upgrade pip setuptools wheel build

# Install with verbose output
# Install with uv
COPY linumpy ./linumpy
COPY scripts ./scripts
COPY pyproject.toml requirements.txt README.md setup.py ./
RUN pip install --no-cache-dir -v -e .
COPY pyproject.toml uv.lock README.md ./
RUN uv sync --frozen --no-dev
29 changes: 18 additions & 11 deletions linumpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
# Configure thread limits FIRST, before any numerical libraries are imported
import os as _os

from linumpy._thread_config import (
configure_thread_limits,
apply_threadpool_limits,
configure_all_libraries,
configure_sitk,
apply_threadpool_limits as apply_threadpool_limits,
)
from linumpy._thread_config import (
configure_all_libraries as configure_all_libraries,
)
from linumpy._thread_config import (
configure_sitk as configure_sitk,
)
from linumpy._thread_config import (
configure_thread_limits as configure_thread_limits,
)

import os as _os
from pathlib import Path as _Path


def get_home():
""" Set a user-writeable file-system location to put files. """
if 'LINUMPY_HOME' in _os.environ:
return _os.environ['LINUMPY_HOME']
return _os.path.join(_os.path.expanduser('~'), '.linumpy')
"""Set a user-writeable file-system location to put files."""
if "LINUMPY_HOME" in _os.environ:
return _os.environ["LINUMPY_HOME"]
return str(_Path.home() / ".linumpy")


def get_root():
return _os.path.realpath(f"{_os.path.dirname(_os.path.abspath(__file__))}/..")
return str(_Path(__file__).resolve().parent.parent)


LINUMPY_HOME = get_home()
Expand Down
115 changes: 61 additions & 54 deletions linumpy/_thread_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
Thread configuration module for linumpy.

Expand Down Expand Up @@ -30,6 +29,7 @@
To ensure proper limiting, scripts should call configure_all_libraries() after imports.
"""

import contextlib
import multiprocessing
import os
import sys
Expand All @@ -47,18 +47,21 @@ def get_max_threads():
"""
total_cpus = multiprocessing.cpu_count()

try:
# Check for explicit max CPUs limit
max_cpus = os.environ.get('LINUMPY_MAX_CPUS')
if max_cpus is not None:
# Check for explicit max CPUs limit
max_cpus = os.environ.get("LINUMPY_MAX_CPUS")
if max_cpus is not None:
try:
return max(1, min(int(max_cpus), total_cpus))
except ValueError:
pass

# Check for reserved CPUs
reserved = os.environ.get('LINUMPY_RESERVED_CPUS')
if reserved is not None:
# Check for reserved CPUs
reserved = os.environ.get("LINUMPY_RESERVED_CPUS")
if reserved is not None:
try:
return max(1, total_cpus - int(reserved))
except ValueError:
pass
except ValueError:
pass

# Default: use all CPUs
return total_cpus
Expand All @@ -79,42 +82,42 @@ def configure_thread_limits():
max_threads = get_max_threads()

# If OMP_NUM_THREADS is already set, use that value instead
if 'OMP_NUM_THREADS' in os.environ:
try:
max_threads = int(os.environ['OMP_NUM_THREADS'])
except ValueError:
pass
if "OMP_NUM_THREADS" in os.environ:
with contextlib.suppress(ValueError):
max_threads = int(os.environ["OMP_NUM_THREADS"])

# Set environment variables for all common threading libraries
# Set ALL of them unconditionally to ensure consistency
thread_vars = [
'OMP_NUM_THREADS', # OpenMP (used by numpy, scipy, etc.)
'MKL_NUM_THREADS', # Intel MKL
'OPENBLAS_NUM_THREADS', # OpenBLAS
'VECLIB_MAXIMUM_THREADS', # macOS Accelerate
'NUMEXPR_NUM_THREADS', # NumExpr
'NUMBA_NUM_THREADS', # Numba
'GOTO_NUM_THREADS', # GotoBLAS
'BLIS_NUM_THREADS', # BLIS
'ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS', # SimpleITK/ITK
'XLA_FLAGS', # JAX/XLA thread pool (set below with special format)
"OMP_NUM_THREADS", # OpenMP (used by numpy, scipy, etc.)
"MKL_NUM_THREADS", # Intel MKL
"OPENBLAS_NUM_THREADS", # OpenBLAS
"VECLIB_MAXIMUM_THREADS", # macOS Accelerate
"NUMEXPR_NUM_THREADS", # NumExpr
"NUMBA_NUM_THREADS", # Numba
"GOTO_NUM_THREADS", # GotoBLAS
"BLIS_NUM_THREADS", # BLIS
"ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS", # SimpleITK/ITK
"XLA_FLAGS", # JAX/XLA thread pool (set below with special format)
]

for var in thread_vars:
if var == 'XLA_FLAGS':
if var == "XLA_FLAGS":
# XLA flags use a special format
# This limits JAX's XLA thread pool (used by BaSiCPy)
xla_flags = os.environ.get('XLA_FLAGS', '')
if f'--xla_cpu_multi_thread_eigen=false' not in xla_flags:
xla_flags = os.environ.get("XLA_FLAGS", "")
if "--xla_cpu_multi_thread_eigen=false" not in xla_flags:
# Disable multi-threading in XLA's Eigen backend for better control
new_flags = f'{xla_flags} --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads={max_threads}'.strip()
os.environ['XLA_FLAGS'] = new_flags
new_flags = (
f"{xla_flags} --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads={max_threads}".strip()
)
os.environ["XLA_FLAGS"] = new_flags
else:
os.environ[var] = str(max_threads)

# Also set dask configuration via environment variable
# This limits dask's thread pool before dask is imported
os.environ['DASK_NUM_WORKERS'] = str(max_threads)
os.environ["DASK_NUM_WORKERS"] = str(max_threads)

return max_threads

Expand All @@ -126,10 +129,11 @@ def configure_dask():
"""
try:
import dask
max_threads = int(os.environ.get('OMP_NUM_THREADS', multiprocessing.cpu_count()))

max_threads = int(os.environ.get("OMP_NUM_THREADS", multiprocessing.cpu_count()))
dask.config.set(num_workers=max_threads)
dask.config.set(scheduler='threads') # Use thread scheduler, not process
dask.config.set({'array.slicing.split_large_chunks': False})
dask.config.set(scheduler="threads") # Use thread scheduler, not process
dask.config.set({"array.slicing.split_large_chunks": False})
except ImportError:
pass

Expand All @@ -145,7 +149,8 @@ def configure_sitk():
"""
try:
import SimpleITK as sitk
max_threads = int(os.environ.get('OMP_NUM_THREADS', multiprocessing.cpu_count()))

max_threads = int(os.environ.get("OMP_NUM_THREADS", multiprocessing.cpu_count()))
sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(max_threads)
except ImportError:
pass
Expand All @@ -165,7 +170,7 @@ def apply_threadpool_limits():
from threadpoolctl import threadpool_limits

# Get the configured thread limit
max_threads = int(os.environ.get('OMP_NUM_THREADS', multiprocessing.cpu_count()))
max_threads = int(os.environ.get("OMP_NUM_THREADS", multiprocessing.cpu_count()))

# Apply limits globally - this returns a context manager but also applies immediately
limiter = threadpool_limits(limits=max_threads)
Expand All @@ -192,20 +197,21 @@ def configure_all_libraries():
"""
global _thread_config_applied

max_threads = int(os.environ.get('OMP_NUM_THREADS', multiprocessing.cpu_count()))
max_threads = int(os.environ.get("OMP_NUM_THREADS", multiprocessing.cpu_count()))

# Configure SimpleITK if imported (CRITICAL - major source of CPU spikes)
if 'SimpleITK' in sys.modules:
if "SimpleITK" in sys.modules:
configure_sitk()

# Configure dask if imported
if 'dask' in sys.modules:
if "dask" in sys.modules:
configure_dask()

# Configure numba if imported
if 'numba' in sys.modules:
if "numba" in sys.modules:
try:
from numba import set_num_threads

set_num_threads(max_threads)
except (ImportError, Exception):
pass
Expand All @@ -226,31 +232,32 @@ def get_thread_info():
dict: Thread configuration information
"""
info = {
'total_cpus': multiprocessing.cpu_count(),
'configured_threads': int(os.environ.get('OMP_NUM_THREADS', multiprocessing.cpu_count())),
'env_vars': {},
'libraries': {},
"total_cpus": multiprocessing.cpu_count(),
"configured_threads": int(os.environ.get("OMP_NUM_THREADS", multiprocessing.cpu_count())),
"env_vars": {},
"libraries": {},
}

# Check environment variables
for var in ['OMP_NUM_THREADS', 'MKL_NUM_THREADS', 'OPENBLAS_NUM_THREADS',
'LINUMPY_MAX_CPUS', 'LINUMPY_RESERVED_CPUS']:
info['env_vars'][var] = os.environ.get(var, 'NOT SET')
for var in ["OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "LINUMPY_MAX_CPUS", "LINUMPY_RESERVED_CPUS"]:
info["env_vars"][var] = os.environ.get(var, "NOT SET")

# Check SimpleITK
if 'SimpleITK' in sys.modules:
if "SimpleITK" in sys.modules:
try:
import SimpleITK as sitk
info['libraries']['SimpleITK'] = sitk.ProcessObject.GetGlobalDefaultNumberOfThreads()

info["libraries"]["SimpleITK"] = sitk.ProcessObject.GetGlobalDefaultNumberOfThreads()
except Exception:
info['libraries']['SimpleITK'] = 'ERROR'
info["libraries"]["SimpleITK"] = "ERROR"

# Check threadpoolctl
try:
from threadpoolctl import threadpool_info
info['libraries']['threadpoolctl'] = threadpool_info()

info["libraries"]["threadpoolctl"] = threadpool_info()
except ImportError:
info['libraries']['threadpoolctl'] = 'NOT INSTALLED'
info["libraries"]["threadpoolctl"] = "NOT INSTALLED"

return info

Expand All @@ -261,10 +268,10 @@ def print_thread_info():
print(f"CPU cores: {info['total_cpus']}")
print(f"Configured threads: {info['configured_threads']}")
print("Environment variables:")
for var, val in info['env_vars'].items():
for var, val in info["env_vars"].items():
print(f" {var}: {val}")
print("Library configurations:")
for lib, val in info['libraries'].items():
for lib, val in info["libraries"].items():
print(f" {lib}: {val}")


Expand Down
19 changes: 15 additions & 4 deletions linumpy/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from .allen import *
from .data_io import *
from .zarr import *
from .npz import write_numpy, read_numpy, read_numpy_data, read_numpy_metadata
from .allen import * # noqa: F403
from .data_io import * # noqa: F403
from .npz import (
read_numpy as read_numpy,
)
from .npz import (
read_numpy_data as read_numpy_data,
)
from .npz import (
read_numpy_metadata as read_numpy_metadata,
)
from .npz import (
write_numpy as write_numpy,
)
from .zarr import * # noqa: F403
Loading
Loading