Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Commit a9378fb

Browse files
authored
Update rotmg interface to handle issue in OpenCL CPU support (#532)
* Update rotmg implementation * Add conditional check on mem_type in use * Add dependencies to copy_y1 operation * Add new interface to library and add related tests.
1 parent 18a32a8 commit a9378fb

File tree

3 files changed

+192
-124
lines changed

3 files changed

+192
-124
lines changed

src/interface/blas1/rotmg.cpp.in

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,22 @@ template typename SB_Handle::event_t _rotmg(
7272
BufferIterator<${DATA_TYPE}> _y1, BufferIterator<${DATA_TYPE}> _param,
7373
const typename SB_Handle::event_t& dependencies);
7474

75+
template typename SB_Handle::event_t _rotmg(
76+
SB_Handle& sb_handle, BufferIterator<${DATA_TYPE}> _d1,
77+
BufferIterator<${DATA_TYPE}> _d2, BufferIterator<${DATA_TYPE}> _x1,
78+
${DATA_TYPE} _y1, BufferIterator<${DATA_TYPE}> _param,
79+
const typename SB_Handle::event_t& dependencies);
80+
7581
#ifdef SB_ENABLE_USM
7682
template typename SB_Handle::event_t _rotmg(
7783
SB_Handle& sb_handle, ${DATA_TYPE} * _d1, ${DATA_TYPE} * _d2,
7884
${DATA_TYPE} * _x1, ${DATA_TYPE} * _y1, ${DATA_TYPE} * _param,
7985
const typename SB_Handle::event_t& dependencies);
86+
87+
template typename SB_Handle::event_t _rotmg(
88+
SB_Handle& sb_handle, ${DATA_TYPE} * _d1, ${DATA_TYPE} * _d2,
89+
${DATA_TYPE} * _x1, ${DATA_TYPE} _y1, ${DATA_TYPE} * _param,
90+
const typename SB_Handle::event_t& dependencies);
8091
#endif
8192

8293
} // namespace internal

src/interface/blas1_interface.hpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -811,14 +811,37 @@ typename sb_handle_t::event_t _rotmg(
811811
auto d1_view = make_vector_view(_d1, inc, vector_size);
812812
auto d2_view = make_vector_view(_d2, inc, vector_size);
813813
auto x1_view = make_vector_view(_x1, inc, vector_size);
814-
auto y1_view = make_vector_view(_y1, inc, vector_size);
815814
auto param_view = make_vector_view(_param, inc, param_size);
816815

817-
auto operation =
818-
Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view, y1_view, param_view);
819-
auto ret = sb_handle.execute(operation, _dependencies);
816+
if constexpr (std::is_arithmetic_v<container_3_t>) {
817+
constexpr helper::AllocType mem_type = std::is_pointer_v<container_0_t>
818+
? helper::AllocType::usm
819+
: helper::AllocType::buffer;
820+
auto _y1_tmp = blas::helper::allocate<mem_type, container_3_t>(
821+
1, sb_handle.get_queue());
820822

821-
return ret;
823+
auto copy_y1 = blas::helper::copy_to_device(sb_handle.get_queue(), &_y1,
824+
_y1_tmp, 1, _dependencies);
825+
826+
auto y1_view = make_vector_view(_y1_tmp, inc, vector_size);
827+
auto operation = Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view,
828+
y1_view, param_view);
829+
830+
auto operator_event =
831+
sb_handle.execute(operation, typename sb_handle_t::event_t{copy_y1});
832+
if constexpr (mem_type != helper::AllocType::buffer) {
833+
// This wait is necessary to free the temporary memory created above and
834+
// avoiding the host_task
835+
operator_event[0].wait();
836+
sycl::free(_y1_tmp, sb_handle.get_queue());
837+
}
838+
return operator_event;
839+
} else {
840+
auto y1_view = make_vector_view(_y1, inc, vector_size);
841+
auto operation = Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view,
842+
y1_view, param_view);
843+
return sb_handle.execute(operation, _dependencies);
844+
}
822845
}
823846

824847
/**

0 commit comments

Comments
 (0)