Skip to content

Commit cc6d03c

Browse files
authored
[BP] Check cub errors (dmlc#10721) (dmlc#10903)
This backport cherry picks the specific fix in the row partitioner instead of the entire patch.
1 parent 1c61752 commit cc6d03c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/tree/gpu_hist/row_partitioner.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,13 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
134134
});
135135
size_t temp_bytes = 0;
136136
if (tmp->empty()) {
137-
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
138-
IndexFlagOp(), total_rows);
137+
dh::safe_cuda(cub::DeviceScan::InclusiveScan(
138+
nullptr, temp_bytes, input_iterator, discard_write_iterator, IndexFlagOp(), total_rows));
139139
tmp->resize(temp_bytes);
140140
}
141141
temp_bytes = tmp->size();
142-
cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
143-
discard_write_iterator, IndexFlagOp(), total_rows);
142+
dh::safe_cuda(cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
143+
discard_write_iterator, IndexFlagOp(), total_rows));
144144

145145
constexpr int kBlockSize = 256;
146146

0 commit comments

Comments
 (0)