@@ -104,23 +104,29 @@ class EllpackHostCacheStreamImpl {
104
104
105
105
this ->cache_ ->sizes_orig .push_back (page.Impl ()->MemCostBytes ());
106
106
auto orig_ptr = this ->cache_ ->sizes_orig .size () - 1 ;
107
+ CHECK_EQ (this ->cache_ ->pages .size (), this ->cache_ ->on_device .size ());
107
108
108
109
CHECK_LT (orig_ptr, this ->cache_ ->NumBatchesOrig ());
109
110
auto cache_idx = this ->cache_ ->cache_mapping .at (orig_ptr);
110
111
// Wrap up the previous page if this is a new page, or this is the last page.
111
112
auto new_page = cache_idx == this ->cache_ ->pages .size ();
112
-
113
+ // Last page expected from the user.
113
114
auto last_page = (orig_ptr + 1 ) == this ->cache_ ->NumBatchesOrig ();
114
- // No page concatenation is performed. If there's page concatenation, then the number
115
- // of pages in the cache must be smaller than the input number of pages.
116
- bool no_concat = this -> cache_ -> NumBatchesOrig () == this -> cache_ -> buffer_rows . size ();
115
+
116
+ bool const no_concat = this -> cache_ -> NoConcat ();
117
+
117
118
// Whether the page should be cached in device. If true, then we don't need to make a
118
119
// copy during write since the temporary page is already in device when page
119
120
// concatenation is enabled.
120
- bool to_device = this ->cache_ ->prefer_device &&
121
- this ->cache_ ->NumDevicePages () < this ->cache_ ->max_num_device_pages ;
122
-
123
- auto commit_page = [&ctx](EllpackPageImpl const * old_impl) {
121
+ //
122
+ // This applies only to a new cached page. If we are concatenating this page to an
123
+ // existing cached page, then we should respect the existing flag obtained from the
124
+ // first page of the cached page.
125
+ bool to_device_if_new_page =
126
+ this ->cache_ ->prefer_device &&
127
+ this ->cache_ ->NumDevicePages () < this ->cache_ ->max_num_device_pages ;
128
+
129
+ auto commit_host_page = [](EllpackPageImpl const * old_impl) {
124
130
CHECK_EQ (old_impl->gidx_buffer .Resource ()->Type (), common::ResourceHandler::kCudaMalloc );
125
131
auto new_impl = std::make_unique<EllpackPageImpl>();
126
132
new_impl->CopyInfo (old_impl);
@@ -137,7 +143,7 @@ class EllpackHostCacheStreamImpl {
137
143
auto new_impl = std::make_unique<EllpackPageImpl>();
138
144
new_impl->CopyInfo (page.Impl ());
139
145
140
- if (to_device ) {
146
+ if (to_device_if_new_page ) {
141
147
// Copy to device
142
148
new_impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(
143
149
page.Impl ()->gidx_buffer .size ());
@@ -151,15 +157,16 @@ class EllpackHostCacheStreamImpl {
151
157
152
158
this ->cache_ ->offsets .push_back (new_impl->n_rows * new_impl->info .row_stride );
153
159
this ->cache_ ->pages .push_back (std::move (new_impl));
160
+ this ->cache_ ->on_device .push_back (to_device_if_new_page);
154
161
return new_page;
155
162
}
156
163
157
164
if (new_page) {
158
165
// No need to copy if it's already in device.
159
- if (!this ->cache_ ->pages .empty () && !to_device ) {
166
+ if (!this ->cache_ ->pages .empty () && !this -> cache_ -> on_device . back () ) {
160
167
// Need to wrap up the previous page.
161
- auto commited = commit_page (this ->cache_ ->pages .back ().get ());
162
- // Replace the previous page with a new page.
168
+ auto commited = commit_host_page (this ->cache_ ->pages .back ().get ());
169
+ // Replace the previous page (on device) with a new page on host .
163
170
this ->cache_ ->pages .back () = std::move (commited);
164
171
}
165
172
// Push a new page
@@ -174,16 +181,18 @@ class EllpackHostCacheStreamImpl {
174
181
auto offset = new_impl->Copy (&ctx, impl, 0 );
175
182
176
183
this ->cache_ ->offsets .push_back (offset);
184
+
177
185
this ->cache_ ->pages .push_back (std::move (new_impl));
186
+ this ->cache_ ->on_device .push_back (to_device_if_new_page);
178
187
} else {
179
188
CHECK (!this ->cache_ ->pages .empty ());
180
189
CHECK_EQ (cache_idx, this ->cache_ ->pages .size () - 1 );
181
190
auto & new_impl = this ->cache_ ->pages .back ();
182
191
auto offset = new_impl->Copy (&ctx, impl, this ->cache_ ->offsets .back ());
183
192
this ->cache_ ->offsets .back () += offset;
184
193
// No need to copy if it's already in device.
185
- if (last_page && !to_device ) {
186
- auto commited = commit_page (this ->cache_ ->pages .back ().get ());
194
+ if (last_page && !this -> cache_ -> on_device . back () ) {
195
+ auto commited = commit_host_page (this ->cache_ ->pages .back ().get ());
187
196
this ->cache_ ->pages .back () = std::move (commited);
188
197
}
189
198
}
0 commit comments