@@ -50,28 +50,29 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
5050
5151template <typename BinIdxType, bool isDense>
5252void GHistIndexMatrix::SetIndexData (::sycl::queue* qu,
53+ Context const * ctx,
5354 BinIdxType* index_data,
54- DMatrix *dmat,
55- size_t nbins,
56- size_t row_stride) {
55+ DMatrix *dmat) {
5756 if (nbins == 0 ) return ;
5857 const bst_float* cut_values = cut.cut_values_ .ConstDevicePointer ();
5958 const uint32_t * cut_ptrs = cut.cut_ptrs_ .ConstDevicePointer ();
6059 size_t * hit_count_ptr = hit_count.DevicePointer ();
6160
6261 BinIdxType* sort_data = reinterpret_cast <BinIdxType*>(sort_buff.Data ());
6362
64- ::sycl::event event;
6563 for (auto &batch : dmat->GetBatches <SparsePage>()) {
66- for (auto &batch : dmat->GetBatches <SparsePage>()) {
67- const xgboost::Entry *data_ptr = batch.data .ConstDevicePointer ();
68- const bst_idx_t *offset_vec = batch.offset .ConstDevicePointer ();
69- size_t batch_size = batch.Size ();
70- if (batch_size > 0 ) {
71- const auto base_rowid = batch.base_rowid ;
72- event = qu->submit ([&](::sycl::handler& cgh) {
73- cgh.depends_on (event);
74- cgh.parallel_for <>(::sycl::range<1 >(batch_size), [=](::sycl::item<1 > pid) {
64+ batch.data .SetDevice (ctx->Device ());
65+ batch.offset .SetDevice (ctx->Device ());
66+
67+ const xgboost::Entry *data_ptr = batch.data .ConstDevicePointer ();
68+ const bst_idx_t *offset_vec = batch.offset .ConstDevicePointer ();
69+ size_t batch_size = batch.Size ();
70+ if (batch_size > 0 ) {
71+ const auto base_rowid = batch.base_rowid ;
72+ size_t row_stride = this ->row_stride ;
73+ size_t nbins = this ->nbins ;
74+ qu->submit ([&](::sycl::handler& cgh) {
75+ cgh.parallel_for <>(::sycl::range<1 >(batch_size), [=](::sycl::item<1 > pid) {
7576 const size_t i = pid.get_id (0 );
7677 const size_t ibegin = offset_vec[i];
7778 const size_t iend = offset_vec[i + 1 ];
@@ -92,23 +93,22 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
9293 }
9394 });
9495 });
95- }
96+ qu-> wait ();
9697 }
9798 }
98- qu->wait ();
9999}
100100
101- void GHistIndexMatrix::ResizeIndex (size_t n_index, bool isDense ) {
102- if ((max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint8_t >::max ())) && isDense ) {
101+ void GHistIndexMatrix::ResizeIndex (::sycl::queue* qu, size_t n_index ) {
102+ if ((max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint8_t >::max ())) && isDense_ ) {
103103 index.SetBinTypeSize (BinTypeSize::kUint8BinsTypeSize );
104- index.Resize ((sizeof (uint8_t )) * n_index);
104+ index.Resize (qu, (sizeof (uint8_t )) * n_index);
105105 } else if ((max_num_bins - 1 > static_cast <int >(std::numeric_limits<uint8_t >::max ()) &&
106- max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint16_t >::max ())) && isDense ) {
106+ max_num_bins - 1 <= static_cast <int >(std::numeric_limits<uint16_t >::max ())) && isDense_ ) {
107107 index.SetBinTypeSize (BinTypeSize::kUint16BinsTypeSize );
108- index.Resize ((sizeof (uint16_t )) * n_index);
108+ index.Resize (qu, (sizeof (uint16_t )) * n_index);
109109 } else {
110110 index.SetBinTypeSize (BinTypeSize::kUint32BinsTypeSize );
111- index.Resize ((sizeof (uint32_t )) * n_index);
111+ index.Resize (qu, (sizeof (uint32_t )) * n_index);
112112 }
113113}
114114
@@ -122,52 +122,50 @@ void GHistIndexMatrix::Init(::sycl::queue* qu,
122122 cut.SetDevice (ctx->Device ());
123123
124124 max_num_bins = max_bins;
125- const uint32_t nbins = cut.Ptrs ().back ();
126- this ->nbins = nbins;
125+ nbins = cut.Ptrs ().back ();
127126
128127 hit_count.SetDevice (ctx->Device ());
129128 hit_count.Resize (nbins, 0 );
130129
131- this ->p_fmat = dmat;
132130 const bool isDense = dmat->IsDense ();
133131 this ->isDense_ = isDense;
134132
135- index.setQueue (qu);
136-
137133 row_stride = 0 ;
138134 size_t n_rows = 0 ;
139- for ( const auto & batch : dmat-> GetBatches <SparsePage>() ) {
140- const auto & row_offset = batch. offset . ConstHostVector ();
141- batch.data . SetDevice (ctx-> Device () );
142- batch.offset . SetDevice (ctx-> Device () );
143- n_rows += batch. Size ();
144- for ( auto i = 1ull ; i < row_offset. size (); i++) {
145- row_stride = std::max (row_stride, static_cast < size_t >(row_offset[i] - row_offset[i - 1 ]));
135+ if (!isDense ) {
136+ for ( const auto & batch : dmat-> GetBatches <SparsePage>()) {
137+ const auto & row_offset = batch.offset . ConstHostVector ( );
138+ n_rows += batch.Size ( );
139+ for ( auto i = 1ull ; i < row_offset. size (); i++) {
140+ row_stride = std::max (row_stride, static_cast < size_t >(row_offset[i] - row_offset[i - 1 ]));
141+ }
146142 }
143+ } else {
144+ row_stride = nfeatures;
145+ n_rows = dmat->Info ().num_row_ ;
147146 }
148147
149148 const size_t n_offsets = cut.cut_ptrs_ .Size () - 1 ;
150149 const size_t n_index = n_rows * row_stride;
151- ResizeIndex (n_index, isDense );
150+ ResizeIndex (qu, n_index );
152151
153152 CHECK_GT (cut.cut_values_ .Size (), 0U );
154153
155154 if (isDense) {
156155 BinTypeSize curent_bin_size = index.GetBinTypeSize ();
157156 if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize ) {
158- SetIndexData<uint8_t , true >(qu, index.data <uint8_t >(), dmat, nbins, row_stride);
159-
157+ SetIndexData<uint8_t , true >(qu, ctx, index.data <uint8_t >(), dmat);
160158 } else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize ) {
161- SetIndexData<uint16_t , true >(qu, index.data <uint16_t >(), dmat, nbins, row_stride );
159+ SetIndexData<uint16_t , true >(qu, ctx, index.data <uint16_t >(), dmat);
162160 } else {
163161 CHECK_EQ (curent_bin_size, BinTypeSize::kUint32BinsTypeSize );
164- SetIndexData<uint32_t , true >(qu, index.data <uint32_t >(), dmat, nbins, row_stride );
162+ SetIndexData<uint32_t , true >(qu, ctx, index.data <uint32_t >(), dmat);
165163 }
166164 /* For sparse DMatrix we have to store index of feature for each bin
167165 in index field to chose right offset. So offset is nullptr and index is not reduced */
168166 } else {
169167 sort_buff.Resize (qu, n_rows * row_stride * sizeof (uint32_t ));
170- SetIndexData<uint32_t , false >(qu, index.data <uint32_t >(), dmat, nbins, row_stride );
168+ SetIndexData<uint32_t , false >(qu, ctx, index.data <uint32_t >(), dmat);
171169 }
172170}
173171
0 commit comments