Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
2fd3c8a
Update JAX Binding to use FFI
ASKabalan Mar 26, 2025
2b591ca
Update JAX Primitive to accept is_linear
ASKabalan Mar 26, 2025
8fe86c2
Update healpix_ffts to use new FFI lowered cuda healpix ffts
ASKabalan Mar 26, 2025
933ac2a
Update benchmarks
ASKabalan Mar 26, 2025
e2cc68c
Update Pyproject.toml and build to include FFI headers
ASKabalan Mar 28, 2025
b5cbeac
Implement VMAP and transpose rules for cuda primitive
ASKabalan Mar 28, 2025
9e0f121
Update JAX binding layer
ASKabalan Mar 28, 2025
92fe6a0
add vmap jacrev and jacfwd tests
ASKabalan Mar 28, 2025
a70b262
Fix build without CUDA NVCC
ASKabalan Mar 28, 2025
0e03787
Implement requested changes
ASKabalan Apr 16, 2025
6f6c07e
Update tests/test_healpix_ffts.py
ASKabalan Apr 16, 2025
f8a9a6d
Merge remote-tracking branch 'origin/main' into ASKabalan
ASKabalan Jun 19, 2025
866d1f2
don't include ffi headers if cuda is not available
ASKabalan Jun 19, 2025
a83dbd1
Fix memory illegal access issue
ASKabalan Jun 28, 2025
fd7860e
remove strict requirement on JAX being less than 0.6.0
ASKabalan Jun 28, 2025
d29af9b
format
ASKabalan Jun 28, 2025
1ac3541
removubg s2fft callbacks
ASKabalan Jun 30, 2025
b75c0ce
code works
ASKabalan Jul 2, 2025
00b169c
Updating CUDA extension and removing CUFFT callbacks
ASKabalan Jul 2, 2025
9775bba
remvove callback params workspace
ASKabalan Jul 2, 2025
fb8d0df
format
ASKabalan Jul 2, 2025
850cd43
Bump pypa/cibuildwheel from 2.23.3 to 3.0.0 (#311)
dependabot[bot] Jul 7, 2025
25b2cc1
Update `python_requires` and test matrix to support Python 3.11+ (#305)
matt-graham Jul 8, 2025
ba5a531
Update Python version used in docs workflow (#314)
matt-graham Jul 8, 2025
bfe89dc
Bump pypa/cibuildwheel from 3.0.0 to 3.0.1 (#313)
dependabot[bot] Jul 21, 2025
64b1ceb
Bump pypa/cibuildwheel from 3.0.1 to 3.1.3 (#318)
dependabot[bot] Aug 11, 2025
2e52da3
Update custom_ops.py (#315)
kmulderdas Aug 11, 2025
f6cd7f4
Bump actions/checkout from 4.2.2 to 5.0.0 (#321)
dependabot[bot] Aug 27, 2025
5152e2c
Bump actions/download-artifact from 4 to 5 (#322)
dependabot[bot] Aug 27, 2025
ac1609d
Bump pypa/cibuildwheel from 3.1.3 to 3.1.4 (#323)
dependabot[bot] Aug 27, 2025
928ea12
Fix race condition error and update notebook
ASKabalan Nov 11, 2025
bca4837
fix pyproject.toml
ASKabalan Nov 11, 2025
50e2840
Merge remote-tracking branch 'upstream/main' into ASKabalan
ASKabalan Nov 11, 2025
757e022
Remove nano_bind helpers reference for License section
ASKabalan Nov 11, 2025
6400681
Fuse normalize and shift kernels for both forward and inverse transfo…
ASKabalan Nov 11, 2025
77cbc96
format
ASKabalan Nov 11, 2025
2a2e6e7
Update notebooks/JAX_CUDA_HEALPix.ipynb
ASKabalan Nov 11, 2025
3bcb69a
fix pre-commit
ASKabalan Nov 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@ repos:
hooks:
- id: ruff
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
hooks:
- id: clang-format
files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$'
exclude: '^third_party/|/pybind11/'
name: clang-format
95 changes: 60 additions & 35 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ set(CMAKE_CUDA_STANDARD 17)

# Set default build type to Release
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
set(CMAKE_BUILD_TYPE
Release
CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release"
"MinSizeRel" "RelWithDebInfo")
endif()

# Check for CUDA
Expand All @@ -23,35 +26,48 @@ if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA compiler found: ${CMAKE_CUDA_COMPILER}")

if(NOT SKBUILD)
message(FATAL_ERROR "Building standalone project directly without pip install is not supported"
"Please use pip install to build the project")
message(
FATAL_ERROR
"Building standalone project directly without pip install is not supported"
"Please use pip install to build the project")
else()
find_package(CUDAToolkit REQUIRED)

find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
# Add the executable
find_package(
Python 3.8 REQUIRED
COMPONENTS Interpreter Development.Module
OPTIONAL_COMPONENTS Development.SABIModule)
execute_process(
COMMAND "${Python_EXECUTABLE}" "-c"
"from jax import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE XLA_DIR)
message(STATUS "XLA include directory: ${XLA_DIR}")

# Detect the installed nanobind package and import it into CMake
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)

nanobind_add_module(_s2fft STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft.cu
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_callbacks.cu
${CMAKE_CURRENT_LIST_DIR}/lib/src/plan_cache.cc
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_kernels.cu
)
find_package(nanobind CONFIG REQUIRED)

nanobind_add_module(
_s2fft
STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft.cu
${CMAKE_CURRENT_LIST_DIR}/lib/src/plan_cache.cc
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_kernels.cu)

target_link_libraries(_s2fft PRIVATE CUDA::cudart_static CUDA::cufft_static CUDA::culibos)
target_include_directories(_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include)
set_target_properties(_s2fft PROPERTIES
LINKER_LANGUAGE CUDA
CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_CUDA_ARCHITECTURES "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
target_include_directories(
_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include ${XLA_DIR} ${CUDAToolkit_INCLUDE_DIRS})
set_target_properties(_s2fft PROPERTIES LINKER_LANGUAGE CUDA
CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -rdc=true")
set(CMAKE_CUDA_ARCHITECTURES
"70;80;89"
CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
message(STATUS "CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
set_target_properties(_s2fft PROPERTIES CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}")
set_target_properties(_s2fft PROPERTIES CUDA_ARCHITECTURES
"${CMAKE_CUDA_ARCHITECTURES}")

install(TARGETS _s2fft LIBRARY DESTINATION s2fft_lib)
endif()
Expand All @@ -60,26 +76,35 @@ else()
if(SKBUILD)
message(WARNING "CUDA compiler not found, building without CUDA support")

find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)

# Add the executable
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
COMMAND "${Python_EXECUTABLE}" "-c"
"from jax import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE XLA_DIR)
message(STATUS "XLA include directory: ${XLA_DIR}")

nanobind_add_module(_s2fft STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc
)
# Detect the installed nanobind package and import it into CMake
find_package(nanobind CONFIG REQUIRED)

nanobind_add_module(_s2fft STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc)

target_compile_definitions(_s2fft PRIVATE NO_CUDA_COMPILER)
target_include_directories(_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include)
target_include_directories(
_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include ${XLA_DIR})

install(TARGETS _s2fft LIBRARY DESTINATION s2fft_lib)

else()
message(FATAL_ERROR "Building standalone project directly without pip install is not supported"
"Please use pip install to build the project")
message(
FATAL_ERROR
"Building standalone project directly without pip install is not supported"
"Please use pip install to build the project")
endif()
endif()


165 changes: 165 additions & 0 deletions lib/include/cudastreamhandler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@

/**
* @file cudastreamhandler.hpp
* @brief Singleton class for managing CUDA streams and events.
*
* This header provides a singleton implementation that encapsulates the creation,
* management, and cleanup of CUDA streams and events. It offers functions to fork
* streams, add new streams, and synchronize (join) streams with a given dependency.
*
* Usage example:
* @code
* #include "cudastreamhandler.hpp"
*
* int main() {
* // Create a handler instance
* CudaStreamHandler handler;
*
* // Fork 4 streams dependent on a given stream 'stream_main'
* handler.Fork(stream_main, 4);
*
* // Do work on the forked streams...
*
* // Join the streams back to 'stream_main'
* handler.join(stream_main);
*
* return 0;
* }
* @endcode
*
* Author: Wassim KABALAN
*/

#ifndef CUDASTREAMHANDLER_HPP
#define CUDASTREAMHANDLER_HPP

#include <algorithm>
#include <atomic>
#include <cuda_runtime.h>
#include <stdexcept>
#include <thread>
#include <vector>

// Singleton class managing CUDA streams and events
class CudaStreamHandlerImpl {
public:
static CudaStreamHandlerImpl &instance() {
static CudaStreamHandlerImpl instance;
return instance;
}

void AddStreams(int numStreams) {
if (numStreams > m_streams.size()) {
int streamsToAdd = numStreams - m_streams.size();
m_streams.resize(numStreams);
std::generate(m_streams.end() - streamsToAdd, m_streams.end(), []() {
cudaStream_t stream;
cudaStreamCreate(&stream);
return stream;
});
}
}

void join(cudaStream_t finalStream) {
std::for_each(m_streams.begin(), m_streams.end(), [this, finalStream](cudaStream_t stream) {
cudaEvent_t event;
cudaEventCreate(&event);
cudaEventRecord(event, stream);
cudaStreamWaitEvent(finalStream, event, 0);
m_events.push_back(event);
});

if (!cleanup_thread.joinable()) {
stop_thread.store(false);
cleanup_thread = std::thread([this]() { this->AsyncEventCleanup(); });
}
}

// Fork function to add streams and set dependency on a given stream
void Fork(cudaStream_t dependentStream, int N) {
AddStreams(N); // Add N streams

// Set dependency on the provided stream
std::for_each(m_streams.end() - N, m_streams.end(), [this, dependentStream](cudaStream_t stream) {
cudaEvent_t event;
cudaEventCreate(&event);
cudaEventRecord(event, dependentStream);
cudaStreamWaitEvent(stream, event, 0); // Set the stream to wait on the event
m_events.push_back(event);
});
}

auto getIterator() { return StreamIterator(m_streams.begin(), m_streams.end()); }

~CudaStreamHandlerImpl() {
stop_thread.store(true);
if (cleanup_thread.joinable()) {
cleanup_thread.join();
}

std::for_each(m_streams.begin(), m_streams.end(), cudaStreamDestroy);
std::for_each(m_events.begin(), m_events.end(), cudaEventDestroy);
}

// Custom Iterator class to iterate over streams
class StreamIterator {
public:
StreamIterator(std::vector<cudaStream_t>::iterator begin, std::vector<cudaStream_t>::iterator end)
: current(begin), end(end) {}

cudaStream_t next() {
if (current == end) {
throw std::out_of_range("No more streams.");
}
return *current++;
}

bool hasNext() const { return current != end; }

private:
std::vector<cudaStream_t>::iterator current;
std::vector<cudaStream_t>::iterator end;
};

private:
CudaStreamHandlerImpl() : stop_thread(false) {}
CudaStreamHandlerImpl(const CudaStreamHandlerImpl &) = delete;
CudaStreamHandlerImpl &operator=(const CudaStreamHandlerImpl &) = delete;

void AsyncEventCleanup() {
while (!stop_thread.load()) {
std::for_each(m_events.begin(), m_events.end(), [this](cudaEvent_t &event) {
if (cudaEventQuery(event) == cudaSuccess) {
cudaEventDestroy(event);
event = nullptr;
}
});
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}

std::vector<cudaStream_t> m_streams;
std::vector<cudaEvent_t> m_events;
std::thread cleanup_thread;
std::atomic<bool> stop_thread;
};

// Public class for encapsulating the singleton operations
class CudaStreamHandler {
public:
CudaStreamHandler() = default;
~CudaStreamHandler() = default;

void AddStreams(int numStreams) { CudaStreamHandlerImpl::instance().AddStreams(numStreams); }

void join(cudaStream_t finalStream) { CudaStreamHandlerImpl::instance().join(finalStream); }

void Fork(cudaStream_t cudastream, int N) { CudaStreamHandlerImpl::instance().Fork(cudastream, N); }

// Get the custom iterator for CUDA streams
CudaStreamHandlerImpl::StreamIterator getIterator() {
return CudaStreamHandlerImpl::instance().getIterator();
}
};

#endif // CUDASTREAMHANDLER_HPP
76 changes: 0 additions & 76 deletions lib/include/kernel_helpers.h
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this file removed we can remove the comment in README

s2fft/README.md

Lines 350 to 352 in d77e9cb

The file [`lib/include/kernel_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h) is adapted from
[code](https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281f3f5dbff579e8f5b00/lib/kernel_helpers.h) in [a tutorial on extending JAX](https://github.com/dfm/extending-jax) by
[Dan Foreman-Mackey](https://github.com/dfm) and licensed under a [MIT license](https://github.com/dfm/extending-jax/blob/371dca93c6405368fa8e71690afd3968d75f4bac/LICENSE).

This file was deleted.

Loading
Loading