@@ -50,29 +50,28 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
5050
5151template <typename BinIdxType, bool isDense>
5252void GHistIndexMatrix::SetIndexData (::sycl::queue* qu,
53- Context const * ctx,
5453 BinIdxType* index_data,
55- DMatrix *dmat) {
54+ DMatrix *dmat,
55+ size_t nbins,
56+ size_t row_stride) {
5657 if (nbins == 0 ) return ;
5758 const bst_float* cut_values = cut.cut_values_ .ConstDevicePointer ();
5859 const uint32_t * cut_ptrs = cut.cut_ptrs_ .ConstDevicePointer ();
5960 size_t * hit_count_ptr = hit_count.DevicePointer ();
6061
6162 BinIdxType* sort_data = reinterpret_cast <BinIdxType*>(sort_buff.Data ());
6263
64+ ::sycl::event event;
6365 for (auto &batch : dmat->GetBatches <SparsePage>()) {
64- batch.data .SetDevice (ctx->Device ());
65- batch.offset .SetDevice (ctx->Device ());
66-
67- const xgboost::Entry *data_ptr = batch.data .ConstDevicePointer ();
68- const bst_idx_t *offset_vec = batch.offset .ConstDevicePointer ();
69- size_t batch_size = batch.Size ();
70- if (batch_size > 0 ) {
71- const auto base_rowid = batch.base_rowid ;
72- size_t row_stride = this ->row_stride ;
73- size_t nbins = this ->nbins ;
74- qu->submit ([&](::sycl::handler& cgh) {
75- cgh.parallel_for <>(::sycl::range<1 >(batch_size), [=](::sycl::item<1 > pid) {
66+ for (auto &batch : dmat->GetBatches <SparsePage>()) {
67+ const xgboost::Entry *data_ptr = batch.data .ConstDevicePointer ();
68+ const bst_idx_t *offset_vec = batch.offset .ConstDevicePointer ();
69+ size_t batch_size = batch.Size ();
70+ if (batch_size > 0 ) {
71+ const auto base_rowid = batch.base_rowid ;
72+ event = qu->submit ([&](::sycl::handler& cgh) {
73+ cgh.depends_on (event);
74+ cgh.parallel_for <>(::sycl::range<1 >(batch_size), [=](::sycl::item<1 > pid) {
7675 const size_t i = pid.get_id (0 );
7776 const size_t ibegin = offset_vec[i];
7877 const size_t iend = offset_vec[i + 1 ];
@@ -93,22 +92,23 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
9392 }
9493 });
9594 });
96- qu-> wait ();
95+ }
9796 }
9897 }
98+ qu->wait ();
9999}
100100
// ResizeIndex: selects the storage width of the bin index and resizes `index` for n_index entries.
//   - uint8_t  when max_num_bins - 1 fits in uint8_t and the matrix is dense;
//   - uint16_t when max_num_bins - 1 exceeds the uint8_t range but fits in uint16_t and the matrix is dense;
//   - uint32_t otherwise (including every non-dense case).
// The value handed to index.Resize is sizeof(element type) * n_index (presumably a byte count -- confirm
// against the Resize implementation, which is not visible here).
// Diff note: the '-' lines are the old version (takes the SYCL queue and reads the member isDense_);
// the '+' lines are the new version (takes the dense flag as a parameter and no longer passes the queue
// to index.Resize).
101- void GHistIndexMatrix::ResizeIndex (::sycl::queue* qu, size_t n_index ) {
102- if ((max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint8_t >::max ())) && isDense_ ) {
101+ void GHistIndexMatrix::ResizeIndex (size_t n_index, bool isDense ) {
102+ if ((max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint8_t >::max ())) && isDense ) {
103103 index.SetBinTypeSize (BinTypeSize::kUint8BinsTypeSize );
104- index.Resize (qu, (sizeof (uint8_t )) * n_index);
104+ index.Resize ((sizeof (uint8_t )) * n_index);
105105 } else if ((max_num_bins - 1 > static_cast <int >(std::numeric_limits<uint8_t >::max ()) &&
106- max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint16_t >::max ())) && isDense_ ) {
106+ max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint16_t >::max ())) && isDense ) {
107107 index.SetBinTypeSize (BinTypeSize::kUint16BinsTypeSize );
108- index.Resize (qu, (sizeof (uint16_t )) * n_index);
108+ index.Resize ((sizeof (uint16_t )) * n_index);
109109 } else {
// Fallback: 32-bit bin indices for large bin counts or non-dense input.
110110 index.SetBinTypeSize (BinTypeSize::kUint32BinsTypeSize );
111- index.Resize (qu, (sizeof (uint32_t )) * n_index);
111+ index.Resize ((sizeof (uint32_t )) * n_index);
112112 }
113113}
114114
@@ -122,50 +122,52 @@ void GHistIndexMatrix::Init(::sycl::queue* qu,
122122 cut.SetDevice (ctx->Device ());
123123
124124 max_num_bins = max_bins;
125- nbins = cut.Ptrs ().back ();
125+ const uint32_t nbins = cut.Ptrs ().back ();
126+ this ->nbins = nbins;
126127
127128 hit_count.SetDevice (ctx->Device ());
128129 hit_count.Resize (nbins, 0 );
129130
131+ this ->p_fmat = dmat;
130132 const bool isDense = dmat->IsDense ();
131133 this ->isDense_ = isDense;
132134
135+ index.setQueue (qu);
136+
133137 row_stride = 0 ;
134138 size_t n_rows = 0 ;
135- if (!isDense ) {
136- for ( const auto & batch : dmat-> GetBatches <SparsePage>()) {
137- const auto & row_offset = batch.offset . ConstHostVector ( );
138- n_rows += batch.Size ( );
139- for ( auto i = 1ull ; i < row_offset. size (); i++) {
140- row_stride = std::max (row_stride, static_cast < size_t >(row_offset[i] - row_offset[i - 1 ]));
141- }
139+ for ( const auto & batch : dmat-> GetBatches <SparsePage>() ) {
140+ const auto & row_offset = batch. offset . ConstHostVector ();
141+ batch.data . SetDevice (ctx-> Device () );
142+ batch.offset . SetDevice (ctx-> Device () );
143+ n_rows += batch. Size ();
144+ for ( auto i = 1ull ; i < row_offset. size (); i++) {
145+ row_stride = std::max (row_stride, static_cast < size_t >(row_offset[i] - row_offset[i - 1 ]));
142146 }
143- } else {
144- row_stride = nfeatures;
145- n_rows = dmat->Info ().num_row_ ;
146147 }
147148
148149 const size_t n_offsets = cut.cut_ptrs_ .Size () - 1 ;
149150 const size_t n_index = n_rows * row_stride;
150- ResizeIndex (qu, n_index );
151+ ResizeIndex (n_index, isDense );
151152
152153 CHECK_GT (cut.cut_values_ .Size (), 0U );
153154
154155 if (isDense) {
155156 BinTypeSize curent_bin_size = index.GetBinTypeSize ();
156157 if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize ) {
157- SetIndexData<uint8_t , true >(qu, ctx, index.data <uint8_t >(), dmat);
158+ SetIndexData<uint8_t , true >(qu, index.data <uint8_t >(), dmat, nbins, row_stride);
159+
158160 } else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize ) {
159- SetIndexData<uint16_t , true >(qu, ctx, index.data <uint16_t >(), dmat);
161+ SetIndexData<uint16_t , true >(qu, index.data <uint16_t >(), dmat, nbins, row_stride );
160162 } else {
161163 CHECK_EQ (curent_bin_size, BinTypeSize::kUint32BinsTypeSize );
162- SetIndexData<uint32_t , true >(qu, ctx, index.data <uint32_t >(), dmat);
164+ SetIndexData<uint32_t , true >(qu, index.data <uint32_t >(), dmat, nbins, row_stride );
163165 }
164166 /* For sparse DMatrix we have to store index of feature for each bin
165167 in index field to choose the right offset. So offset is nullptr and index is not reduced */
166168 } else {
167169 sort_buff.Resize (qu, n_rows * row_stride * sizeof (uint32_t ));
168- SetIndexData<uint32_t , false >(qu, ctx, index.data <uint32_t >(), dmat);
170+ SetIndexData<uint32_t , false >(qu, index.data <uint32_t >(), dmat, nbins, row_stride );
169171 }
170172}
171173
0 commit comments