Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions projects/hipfft/shared/rocfft_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,17 +355,33 @@ class rocfft_params_base : public fft_params

// Return the number of expected callback entries for supplied
// fields.
static size_t expected_callback_count(const std::vector<fft_field>& fields)
size_t expected_callback_count(const std::vector<fft_field>& fields)
{
// If fields are not specified, we consider the input or
// output to have a single brick (and thus expect a single
// callback entry)
if(fields.empty())
return 1;
return std::accumulate(fields.begin(),
fields.end(),
static_cast<size_t>(0),
[](size_t s, const fft_field& f) { return s + f.bricks.size(); });

int mpi_rank = 0;
#ifdef ROCFFT_MPI_ENABLE
if(mp_lib == fft_mp_lib_mpi)
{
MPI_Comm_rank(*static_cast<MPI_Comm*>(mp_comm), &mpi_rank);
}
#endif

// count the number of bricks on this rank
size_t expected_callbacks = 0;
for(const auto& f : fields)
{
for(const auto& b : f.bricks)
{
if(b.rank == mpi_rank)
++expected_callbacks;
}
}
return expected_callbacks;
}

fft_status set_callbacks(std::vector<void*>* load_cb_func,
Expand Down
34 changes: 26 additions & 8 deletions projects/hipfft/shared/test_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include "test_callbacks.h"
#include "rocfft_complex.h"

#ifdef ROCFFT_MPI_ENABLE
#include <mpi.h>
#endif

// load/store callbacks - cbdata in each is actually a scalar double
// with a number to apply to each element
template <typename Tdata>
Expand Down Expand Up @@ -592,7 +596,7 @@ void apply_load_callback(const fft_params& params, std::vector<hostbuf>& input)
}
}

// For a specified rank, get a vector of load callback function +
// For the current rank, get a vector of load callback function +
// data pointers. The pointers need to be in the order that
// fields+bricks were specified to the FFT plan. Pointers need to be
// copied to the host from the device specified by the respective
Expand All @@ -602,9 +606,16 @@ void get_rank_load_callbacks(const fft_params& params,
std::vector<void*>& load_cb_data,
callback_hip_error_handler runtime_err_handler,
bool round_trip_inverse,
std::vector<gpubuf_t<callback_test_data>>& all_cb_data,
int rank)
std::vector<gpubuf_t<callback_test_data>>& all_cb_data)
{
int mpi_rank = 0;
#ifdef ROCFFT_MPI_ENABLE
if(params.mp_lib == fft_params::fft_mp_lib_mpi)
{
MPI_Comm_rank(*static_cast<MPI_Comm*>(params.mp_comm), &mpi_rank);
}
#endif

// Copy callback pointer from current device and add to output vec
auto add_load_cb = [&]() {
void* load_cb_host = get_load_callback_host(
Expand Down Expand Up @@ -663,7 +674,7 @@ void get_rank_load_callbacks(const fft_params& params,
// on this rank
for(size_t i = 0; i < params.ifields.front().bricks.size(); ++i)
{
if(params.ifields.front().bricks[i].rank != rank)
if(params.ifields.front().bricks[i].rank != mpi_rank)
continue;

// load cb for this brick's device
Expand All @@ -673,7 +684,7 @@ void get_rank_load_callbacks(const fft_params& params,
}
}

// For a specified rank, get a vector of store callback function +
// For the current rank, get a vector of store callback function +
// data pointers. The pointers need to be in the order that
// fields+bricks were specified to the FFT plan. Pointers need to be
// copied to the host from the device specified by the respective
Expand All @@ -683,9 +694,16 @@ void get_rank_store_callbacks(const fft_params& params,
std::vector<void*>& store_cb_data,
callback_hip_error_handler runtime_err_handler,
bool round_trip_inverse,
std::vector<gpubuf_t<callback_test_data>>& all_cb_data,
int rank)
std::vector<gpubuf_t<callback_test_data>>& all_cb_data)
{
int mpi_rank = 0;
#ifdef ROCFFT_MPI_ENABLE
if(params.mp_lib == fft_params::fft_mp_lib_mpi)
{
MPI_Comm_rank(*static_cast<MPI_Comm*>(params.mp_comm), &mpi_rank);
}
#endif

// Copy callback pointer from current device and add to output vec
auto add_store_cb = [&]() {
void* store_cb_host = get_store_callback_host(
Expand Down Expand Up @@ -746,7 +764,7 @@ void get_rank_store_callbacks(const fft_params& params,
// on this rank
for(size_t i = 0; i < params.ofields.front().bricks.size(); ++i)
{
if(params.ofields.front().bricks[i].rank != rank)
if(params.ofields.front().bricks[i].rank != mpi_rank)
continue;

// store cb for this brick's device
Expand Down
6 changes: 2 additions & 4 deletions projects/hipfft/shared/test_callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ void get_rank_load_callbacks(const fft_params& params,
std::vector<void*>& load_cb_data,
callback_hip_error_handler runtime_err_handler,
bool round_trip_inverse,
std::vector<gpubuf_t<callback_test_data>>& all_cb_data,
int rank = 0);
std::vector<gpubuf_t<callback_test_data>>& all_cb_data);

// Collect store callback function and data pointers for the given
// params. We'd expect N pointers for N output bricks on the current
Expand All @@ -76,8 +75,7 @@ void get_rank_store_callbacks(const fft_params& params,
std::vector<void*>& store_cb_data,
callback_hip_error_handler runtime_err_handler,
bool round_trip_inverse,
std::vector<gpubuf_t<callback_test_data>>& all_cb_data,
int rank = 0);
std::vector<gpubuf_t<callback_test_data>>& all_cb_data);

// Execute the load/store callback function on a host buffer, to
// ensure that the reference host FFT is comparable to a device FFT
Expand Down
1 change: 1 addition & 0 deletions projects/rocfft/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Documentation for rocFFT is available at
* Fixed potential launch failure of data generation kernels in test and benchmark programs.
* Fixed incorrect results on some strided real-complex FFTs on gfx90a.
* Fixed incorrect results on some even-length real FFTs that have odd-length strides on higher dimensions.
* Fixed callbacks on MPI transforms, when not all ranks have the same number of data bricks.

## rocFFT 1.0.36 for ROCm 7.2.0

Expand Down
15 changes: 9 additions & 6 deletions projects/rocfft/library/src/callback_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
#include "transform.h"

std::map<int, device_callback_t> DeviceCallbackMap(const rocfft_execution_info_t* info,
const rocfft_plan_description_t& desc)
const rocfft_plan_description_t& desc,
int local_comm_rank)
{
// tolerate user not providing an execution_info
rocfft_execution_info_t exec_info;
Expand All @@ -36,16 +37,18 @@ std::map<int, device_callback_t> DeviceCallbackMap(const rocfft_execution_info_t

std::map<int, device_callback_t> callbacks;

auto set_field_callback = [&callbacks](const std::vector<rocfft_field_t>& fields,
void** src_fn,
void** src_data,
bool load) {
auto set_field_callback = [=, &callbacks](const std::vector<rocfft_field_t>& fields,
void** src_fn,
void** src_data,
bool load) {
size_t src_idx = 0;
for(const auto& f : fields)
{
for(const auto& b : f.bricks)
{
int device_id = b.location.device;
if(b.location.comm_rank != local_comm_rank)
continue;

if(load)
{
Expand Down Expand Up @@ -119,4 +122,4 @@ std::map<int, device_callback_t> DeviceCallbackMap(const rocfft_execution_info_t
}

return callbacks;
}
}
3 changes: 2 additions & 1 deletion projects/rocfft/library/src/include/callback_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@ struct device_callback_t
struct rocfft_execution_info_t;
struct rocfft_plan_description_t;
std::map<int, device_callback_t> DeviceCallbackMap(const rocfft_execution_info_t* info,
const rocfft_plan_description_t& desc);
const rocfft_plan_description_t& desc,
int local_comm_rank);
#endif
2 changes: 1 addition & 1 deletion projects/rocfft/library/src/transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ void rocfft_plan_t::Execute(void* in_buffer[], void* out_buffer[], rocfft_execut

LogSortedPlan(sortedIdx);

auto callbacks = DeviceCallbackMap(info, desc);
auto callbacks = DeviceCallbackMap(info, desc, local_comm_rank);

for(auto i = sortedIdx.begin(); i != sortedIdx.end(); ++i)
{
Expand Down
26 changes: 21 additions & 5 deletions projects/rocfft/shared/rocfft_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,17 +355,33 @@ class rocfft_params_base : public fft_params

// Return the number of expected callback entries for supplied
// fields.
static size_t expected_callback_count(const std::vector<fft_field>& fields)
size_t expected_callback_count(const std::vector<fft_field>& fields)
{
// If fields are not specified, we consider the input or
// output to have a single brick (and thus expect a single
// callback entry)
if(fields.empty())
return 1;
return std::accumulate(fields.begin(),
fields.end(),
static_cast<size_t>(0),
[](size_t s, const fft_field& f) { return s + f.bricks.size(); });

int mpi_rank = 0;
#ifdef ROCFFT_MPI_ENABLE
if(mp_lib == fft_mp_lib_mpi)
{
MPI_Comm_rank(*static_cast<MPI_Comm*>(mp_comm), &mpi_rank);
}
#endif

// count the number of bricks on this rank
size_t expected_callbacks = 0;
for(const auto& f : fields)
{
for(const auto& b : f.bricks)
{
if(b.rank == mpi_rank)
++expected_callbacks;
}
}
return expected_callbacks;
}

fft_status set_callbacks(std::vector<void*>* load_cb_func,
Expand Down
34 changes: 26 additions & 8 deletions projects/rocfft/shared/test_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include "test_callbacks.h"
#include "rocfft_complex.h"

#ifdef ROCFFT_MPI_ENABLE
#include <mpi.h>
#endif

// load/store callbacks - cbdata in each is actually a scalar double
// with a number to apply to each element
template <typename Tdata>
Expand Down Expand Up @@ -592,7 +596,7 @@ void apply_load_callback(const fft_params& params, std::vector<hostbuf>& input)
}
}

// For a specified rank, get a vector of load callback function +
// For the current rank, get a vector of load callback function +
// data pointers. The pointers need to be in the order that
// fields+bricks were specified to the FFT plan. Pointers need to be
// copied to the host from the device specified by the respective
Expand All @@ -602,9 +606,16 @@ void get_rank_load_callbacks(const fft_params& params,
std::vector<void*>& load_cb_data,
callback_hip_error_handler runtime_err_handler,
bool round_trip_inverse,
std::vector<gpubuf_t<callback_test_data>>& all_cb_data,
int rank)
std::vector<gpubuf_t<callback_test_data>>& all_cb_data)
{
int mpi_rank = 0;
#ifdef ROCFFT_MPI_ENABLE
if(params.mp_lib == fft_params::fft_mp_lib_mpi)
{
MPI_Comm_rank(*static_cast<MPI_Comm*>(params.mp_comm), &mpi_rank);
}
#endif

// Copy callback pointer from current device and add to output vec
auto add_load_cb = [&]() {
void* load_cb_host = get_load_callback_host(
Expand Down Expand Up @@ -663,7 +674,7 @@ void get_rank_load_callbacks(const fft_params& params,
// on this rank
for(size_t i = 0; i < params.ifields.front().bricks.size(); ++i)
{
if(params.ifields.front().bricks[i].rank != rank)
if(params.ifields.front().bricks[i].rank != mpi_rank)
continue;

// load cb for this brick's device
Expand All @@ -673,7 +684,7 @@ void get_rank_load_callbacks(const fft_params& params,
}
}

// For a specified rank, get a vector of store callback function +
// For the current rank, get a vector of store callback function +
// data pointers. The pointers need to be in the order that
// fields+bricks were specified to the FFT plan. Pointers need to be
// copied to the host from the device specified by the respective
Expand All @@ -683,9 +694,16 @@ void get_rank_store_callbacks(const fft_params& params,
std::vector<void*>& store_cb_data,
callback_hip_error_handler runtime_err_handler,
bool round_trip_inverse,
std::vector<gpubuf_t<callback_test_data>>& all_cb_data,
int rank)
std::vector<gpubuf_t<callback_test_data>>& all_cb_data)
{
int mpi_rank = 0;
#ifdef ROCFFT_MPI_ENABLE
if(params.mp_lib == fft_params::fft_mp_lib_mpi)
{
MPI_Comm_rank(*static_cast<MPI_Comm*>(params.mp_comm), &mpi_rank);
}
#endif

// Copy callback pointer from current device and add to output vec
auto add_store_cb = [&]() {
void* store_cb_host = get_store_callback_host(
Expand Down Expand Up @@ -746,7 +764,7 @@ void get_rank_store_callbacks(const fft_params& params,
// on this rank
for(size_t i = 0; i < params.ofields.front().bricks.size(); ++i)
{
if(params.ofields.front().bricks[i].rank != rank)
if(params.ofields.front().bricks[i].rank != mpi_rank)
continue;

// store cb for this brick's device
Expand Down
6 changes: 2 additions & 4 deletions projects/rocfft/shared/test_callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ void get_rank_load_callbacks(const fft_params& params,
std::vector<void*>& load_cb_data,
callback_hip_error_handler runtime_err_handler,
bool round_trip_inverse,
std::vector<gpubuf_t<callback_test_data>>& all_cb_data,
int rank = 0);
std::vector<gpubuf_t<callback_test_data>>& all_cb_data);

// Collect store callback function and data pointers for the given
// params. We'd expect N pointers for N output bricks on the current
Expand All @@ -76,8 +75,7 @@ void get_rank_store_callbacks(const fft_params& params,
std::vector<void*>& store_cb_data,
callback_hip_error_handler runtime_err_handler,
bool round_trip_inverse,
std::vector<gpubuf_t<callback_test_data>>& all_cb_data,
int rank = 0);
std::vector<gpubuf_t<callback_test_data>>& all_cb_data);

// Execute the load/store callback function on a host buffer, to
// ensure that the reference host FFT is comparable to a device FFT
Expand Down
Loading