@@ -104,23 +104,29 @@ class EllpackHostCacheStreamImpl {
104104
105105 this ->cache_ ->sizes_orig .push_back (page.Impl ()->MemCostBytes ());
106106 auto orig_ptr = this ->cache_ ->sizes_orig .size () - 1 ;
107+ CHECK_EQ (this ->cache_ ->pages .size (), this ->cache_ ->on_device .size ());
107108
108109 CHECK_LT (orig_ptr, this ->cache_ ->NumBatchesOrig ());
109110 auto cache_idx = this ->cache_ ->cache_mapping .at (orig_ptr);
110111 // Wrap up the previous page if this is a new page, or this is the last page.
111112 auto new_page = cache_idx == this ->cache_ ->pages .size ();
112-
113+ // Last page expected from the user.
113114 auto last_page = (orig_ptr + 1 ) == this ->cache_ ->NumBatchesOrig ();
114- // No page concatenation is performed. If there's page concatenation, then the number
115- // of pages in the cache must be smaller than the input number of pages.
116- bool no_concat = this -> cache_ -> NumBatchesOrig () == this -> cache_ -> buffer_rows . size ();
115+
116+ bool const no_concat = this -> cache_ -> NoConcat ();
117+
117118 // Whether the page should be cached in device. If true, then we don't need to make a
118119 // copy during write since the temporary page is already in device when page
119120 // concatenation is enabled.
120- bool to_device = this ->cache_ ->prefer_device &&
121- this ->cache_ ->NumDevicePages () < this ->cache_ ->max_num_device_pages ;
122-
123- auto commit_page = [&ctx](EllpackPageImpl const * old_impl) {
121+ //
122+ // This applies only to a new cached page. If we are concatenating this page to an
123+ // existing cached page, then we should respect the existing flag obtained from the
124+ // first page of the cached page.
125+ bool to_device_if_new_page =
126+ this ->cache_ ->prefer_device &&
127+ this ->cache_ ->NumDevicePages () < this ->cache_ ->max_num_device_pages ;
128+
129+ auto commit_host_page = [](EllpackPageImpl const * old_impl) {
124130 CHECK_EQ (old_impl->gidx_buffer .Resource ()->Type (), common::ResourceHandler::kCudaMalloc );
125131 auto new_impl = std::make_unique<EllpackPageImpl>();
126132 new_impl->CopyInfo (old_impl);
@@ -137,7 +143,7 @@ class EllpackHostCacheStreamImpl {
137143 auto new_impl = std::make_unique<EllpackPageImpl>();
138144 new_impl->CopyInfo (page.Impl ());
139145
140- if (to_device ) {
146+ if (to_device_if_new_page ) {
141147 // Copy to device
142148 new_impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(
143149 page.Impl ()->gidx_buffer .size ());
@@ -151,15 +157,16 @@ class EllpackHostCacheStreamImpl {
151157
152158 this ->cache_ ->offsets .push_back (new_impl->n_rows * new_impl->info .row_stride );
153159 this ->cache_ ->pages .push_back (std::move (new_impl));
160+ this ->cache_ ->on_device .push_back (to_device_if_new_page);
154161 return new_page;
155162 }
156163
157164 if (new_page) {
158165 // No need to copy if it's already in device.
159- if (!this ->cache_ ->pages .empty () && !to_device ) {
166+ if (!this ->cache_ ->pages .empty () && !this -> cache_ -> on_device . back () ) {
160167 // Need to wrap up the previous page.
161- auto commited = commit_page (this ->cache_ ->pages .back ().get ());
162- // Replace the previous page with a new page.
168+ auto commited = commit_host_page (this ->cache_ ->pages .back ().get ());
169+ // Replace the previous page (on device) with a new page on host .
163170 this ->cache_ ->pages .back () = std::move (commited);
164171 }
165172 // Push a new page
@@ -174,16 +181,18 @@ class EllpackHostCacheStreamImpl {
174181 auto offset = new_impl->Copy (&ctx, impl, 0 );
175182
176183 this ->cache_ ->offsets .push_back (offset);
184+
177185 this ->cache_ ->pages .push_back (std::move (new_impl));
186+ this ->cache_ ->on_device .push_back (to_device_if_new_page);
178187 } else {
179188 CHECK (!this ->cache_ ->pages .empty ());
180189 CHECK_EQ (cache_idx, this ->cache_ ->pages .size () - 1 );
181190 auto & new_impl = this ->cache_ ->pages .back ();
182191 auto offset = new_impl->Copy (&ctx, impl, this ->cache_ ->offsets .back ());
183192 this ->cache_ ->offsets .back () += offset;
184193 // No need to copy if it's already in device.
185- if (last_page && !to_device ) {
186- auto commited = commit_page (this ->cache_ ->pages .back ().get ());
194+ if (last_page && !this -> cache_ -> on_device . back () ) {
195+ auto commited = commit_host_page (this ->cache_ ->pages .back ().get ());
187196 this ->cache_ ->pages .back () = std::move (commited);
188197 }
189198 }
0 commit comments