24 | 24 | import org.apache.sysds.runtime.io.IOUtilFunctions; |
25 | 25 | import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
26 | 26 | import org.apache.sysds.runtime.matrix.data.MatrixIndexes; |
27 | | -import org.apache.sysds.runtime.util.FastBufferedDataInputStream; |
28 | 27 | import org.apache.sysds.runtime.util.FastBufferedDataOutputStream; |
29 | 28 | import org.apache.sysds.runtime.util.LocalFileUtils; |
30 | 29 |  |
31 | 30 | import java.io.DataInputStream; |
32 | 31 | import java.io.File; |
33 | | -import java.io.FileInputStream; |
34 | | -import java.io.FileNotFoundException; |
35 | 32 | import java.io.FileOutputStream; |
36 | 33 | import java.io.IOException; |
37 | 34 | import java.io.RandomAccessFile; |
38 | 35 | import java.nio.channels.Channels; |
39 | 36 | import java.util.ArrayList; |
40 | 37 | import java.util.Iterator; |
41 | 38 | import java.util.LinkedHashMap; |
42 | | -import java.util.LinkedList; |
43 | 39 | import java.util.List; |
44 | 40 | import java.util.Map; |
45 | 41 | import java.util.Set; |
@@ -230,7 +226,7 @@ public static IndexedMatrixValue get(long streamId, int blockId) { |
230 | 226 | // 1. wait for eviction to complete |
231 | 227 | while (imv.state == BlockState.EVICTING) { |
232 | 228 | try { |
233 | | - imv.stateUpdate.wait(); |
| 229 | + imv.stateUpdate.await(); |
234 | 230 | } catch (InterruptedException e) { |
235 | 231 |  |
236 | 232 | throw new DMLRuntimeException(e); |
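The switch from `wait()` to `await()` is the substantive fix in this hunk: `stateUpdate` is a `java.util.concurrent.locks.Condition`, and calling `Object.wait()` on a Condition object requires holding its intrinsic monitor (throwing `IllegalMonitorStateException` otherwise) and is never woken by `Condition.signalAll()`. Below is a minimal sketch of the lock/condition handshake this code follows; the class and method names are illustrative, not from SysDS:

```java
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// Minimal sketch of the lock/condition handshake used by get() above.
// The class and method names are illustrative, not from SysDS.
class BlockStateWaiter {
    enum BlockState { HOT, EVICTING, COLD }

    private final Lock lock = new ReentrantLock();
    private final Condition stateUpdate = lock.newCondition();
    private BlockState state = BlockState.EVICTING;

    // A reader blocks until eviction finishes; await() atomically releases
    // the Lock and re-acquires it before returning.
    void waitUntilNotEvicting() throws InterruptedException {
        lock.lock();
        try {
            while (state == BlockState.EVICTING)
                stateUpdate.await(); // not Object.wait(): that targets the monitor, not the Lock
        } finally {
            lock.unlock();
        }
    }

    // The evictor flips the state and wakes all waiting readers.
    void finishEviction(BlockState next) {
        lock.lock();
        try {
            state = next;
            stateUpdate.signalAll();
        } finally {
            lock.unlock();
        }
    }
}
```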
@@ -275,19 +271,20 @@ private static void evict() { |
275 | 271 | Map.Entry<String, BlockEntry> e = iter.next(); |
276 | 272 | BlockEntry entry = e.getValue(); |
277 | 273 |  |
278 | | - entry.lock.lock(); |
279 | | - try { |
280 | | - if (entry.state == BlockState.HOT) { |
281 | | - entry.state = BlockState.EVICTING; |
282 | | - candidates.add(e); |
283 | | - totalFreedSize += entry.size; |
284 | | - |
285 | | - //remove current iterator entry |
286 | | - iter.remove(); |
| 274 | + if (entry.lock.tryLock()) { |
| 275 | + try { |
| 276 | + if (entry.state == BlockState.HOT) { |
| 277 | + entry.state = BlockState.EVICTING; |
| 278 | + candidates.add(e); |
| 279 | + totalFreedSize += entry.size; |
| 280 | + |
| 281 | + // note: iter.remove() intentionally skipped; the entry stays in the cache |
| 282 | +// iter.remove(); |
| 283 | + } |
| 284 | + } finally { |
| 285 | + entry.lock.unlock(); |
287 | 286 | } |
288 | | - } finally { |
289 | | - entry.lock.unlock(); |
290 | | - } |
| 287 | + } // if tryLock() fails, another thread is loading/reading this block, so skip it |
291 | 288 | } |
292 | 289 |  |
293 | 290 | } |
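Replacing `lock()` with `tryLock()` turns the candidate sweep into a non-blocking scan: if another thread holds an entry's lock (for example, a `get()` that is reloading the block), the evictor moves on instead of stalling behind it. A compact sketch of the pattern under simplified assumptions; `Entry` is a hypothetical stand-in for `BlockEntry`:

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.ReentrantLock;

// Sketch: collect eviction candidates without blocking on busy entries.
// Entry is a simplified, hypothetical stand-in for BlockEntry.
public class TryLockSweep {
    static final class Entry {
        final ReentrantLock lock = new ReentrantLock();
        boolean hot = true; // stand-in for state == BlockState.HOT
        long size;
    }

    static long collect(Map<String, Entry> cache, List<Map.Entry<String, Entry>> candidates) {
        long freed = 0;
        for (Map.Entry<String, Entry> e : cache.entrySet()) {
            Entry entry = e.getValue();
            if (!entry.lock.tryLock())
                continue; // held by a reader/loader: skip it, don't stall the sweep
            try {
                if (entry.hot) {
                    entry.hot = false; // the real code flips the state to EVICTING
                    candidates.add(e);
                    freed += entry.size;
                }
            } finally {
                entry.lock.unlock();
            }
        }
        return freed;
    }
}
```

The trade-off is that a busy block simply survives this eviction round, which is safe here because the sweep runs again the next time the size limit is exceeded.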
@@ -316,57 +313,42 @@ private static void evict() { |
316 | 313 | fos = new FileOutputStream(filename); |
317 | 314 | dos = new FastBufferedDataOutputStream(fos); |
318 | 315 |  |
319 | | - int pos = 0; |
320 | | - while(_size.get() > _limit && pos++ < _cache.size()) { |
321 | | - System.err.println("BUFFER: "+_size+"/"+_limit+" size="+_cache.size()); |
322 | | - |
323 | | - // loop over the list of blocks we collected |
324 | | - for (Map.Entry<String,BlockEntry> tmp : candidates) { |
325 | | - BlockEntry entry = tmp.getValue(); |
326 | | - |
327 | | - // Skip if block is null. i.e, COLD |
328 | | -// if (entry.value.getValue() == null) { |
329 | | -// synchronized (_cacheLock) { |
330 | | -// _cache.put(tmp.getKey(), entry); |
331 | | -// } |
332 | | -// continue; |
333 | | -// } |
334 | | - |
335 | | - // 1. get the current file position. this is the offset. |
336 | | - // flush any buffered data to the file |
337 | | - dos.flush(); |
338 | | - long offset = fos.getChannel().position(); |
339 | | - |
340 | | - // 2. write indexes and block |
341 | | - entry.value.getIndexes().write(dos); // write Indexes |
342 | | - entry.value.getValue().write(dos); |
343 | | - |
344 | | - // 3. create the spillLocation |
345 | | - spillLocation sloc = new spillLocation(partitionId, offset); |
346 | | - _spillLocations.put(tmp.getKey(), sloc); |
347 | | - |
348 | | - // 4. track file for cleanup |
349 | | - _streamPartitions |
350 | | - .computeIfAbsent(entry.streamId, k -> ConcurrentHashMap.newKeySet()) |
351 | | - .add(filename); |
352 | | - |
353 | | - // account for memory |
354 | | - long freedSize = estimateSerializedSize((MatrixBlock) tmp.getValue().value.getValue()); |
355 | | - totalFreedSize += freedSize; |
356 | | - |
357 | | - // 5. change state to COLD |
358 | | - entry.lock.lock(); |
359 | | - try { |
360 | | - entry.value.setValue(null); |
361 | | - entry.state = BlockState.COLD; // set state to cold, since writing to disk |
362 | | - entry.stateUpdate.signalAll(); // wake up any "get()" threads |
363 | | - } finally { |
364 | | - entry.lock.unlock(); |
365 | | - } |
366 | 316 |  |
367 | | - synchronized (_cacheLock) { |
368 | | - _cache.put(tmp.getKey(), entry); // add last semantic |
369 | | - } |
| 317 | + // loop over the list of blocks we collected |
| 318 | + for (Map.Entry<String,BlockEntry> tmp : candidates) { |
| 319 | + BlockEntry entry = tmp.getValue(); |
| 320 | + |
| 321 | + // 1. get the current file position. this is the offset. |
| 322 | + // flush any buffered data to the file |
| 323 | + dos.flush(); |
| 324 | + long offset = fos.getChannel().position(); |
| 325 | + |
| 326 | + // 2. write indexes and block |
| 327 | + entry.value.getIndexes().write(dos); // write Indexes |
| 328 | + entry.value.getValue().write(dos); |
| 329 | + System.out.println("written, partition id: " + _partitions.get(partitionId) + ", offset: " + offset); |
| 330 | + |
| 331 | + // 3. create the spillLocation |
| 332 | + spillLocation sloc = new spillLocation(partitionId, offset); |
| 333 | + _spillLocations.put(tmp.getKey(), sloc); |
| 334 | + |
| 335 | + // 4. track file for cleanup |
| 336 | + _streamPartitions |
| 337 | + .computeIfAbsent(entry.streamId, k -> ConcurrentHashMap.newKeySet()) |
| 338 | + .add(filename); |
| 339 | + |
| 340 | + // 5. change state to COLD |
| 341 | + entry.lock.lock(); |
| 342 | + try { |
| 343 | + entry.value = null; // only release ref, don't mutate object |
| 344 | + entry.state = BlockState.COLD; // set state to cold, since writing to disk |
| 345 | + entry.stateUpdate.signalAll(); // wake up any "get()" threads |
| 346 | + } finally { |
| 347 | + entry.lock.unlock(); |
| 348 | + } |
| 349 | + |
| 350 | + synchronized (_cacheLock) { |
| 351 | + _cache.put(tmp.getKey(), entry); // add last semantic |
370 | 352 | } |
371 | 353 | } |
372 | 354 | } |
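Step 1 of the spill loop depends on a subtle contract: `fos.getChannel().position()` reports how many bytes have actually reached the `FileOutputStream`, so the preceding `dos.flush()` is what makes the recorded offset match the logical write position of the buffered stream. A runnable sketch of the same flush-then-position idiom, using the JDK's `DataOutputStream` in place of `FastBufferedDataOutputStream` (assumed to honor the same flush contract); file name and record layout are made up for illustration:

```java
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

// Sketch: record the byte offset of each serialized record so it can be
// re-read later via RandomAccessFile.seek(offset). File name and record
// layout are made up for illustration.
public class SpillOffsetDemo {
    public static void main(String[] args) throws IOException {
        Map<String, Long> offsets = new HashMap<>();
        try (FileOutputStream fos = new FileOutputStream("spill.bin");
             DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(fos))) {
            for (int i = 0; i < 3; i++) {
                dos.flush(); // push buffered bytes down so the channel position is accurate
                long offset = fos.getChannel().position();
                dos.writeInt(i);           // stand-in for indexes.write(dos)
                dos.writeUTF("block" + i); // stand-in for block.write(dos)
                offsets.put("key" + i, offset);
            }
        }
        System.out.println(offsets); // key0 -> 0, key1/key2 -> growing offsets
    }
}
```

Recovering a spilled record is then a matter of seeking to the stored offset and replaying the matching reads, which fits the `RandomAccessFile` import retained at the top of the file.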
@@ -425,18 +407,29 @@ private static IndexedMatrixValue loadFromDisk(long streamId, int blockId) { |
425 | 407 | imvCacheEntry = _cache.get(key); |
426 | 408 | } |
427 | 409 |  |
| 410 | + // 2. check that the entry is still in the cache |
| 411 | + if(imvCacheEntry == null) { |
| 412 | + throw new DMLRuntimeException("Block entry " + key + " was not in cache during load."); |
| 413 | + } |
| 414 | + |
428 | 415 | imvCacheEntry.lock.lock(); |
429 | 416 | try { |
430 | 417 | if (imvCacheEntry.state == BlockState.COLD) { |
431 | 418 | imvCacheEntry.value = new IndexedMatrixValue(ix, mb); |
432 | 419 | imvCacheEntry.state = BlockState.HOT; |
433 | 420 | _size.addAndGet(imvCacheEntry.size); |
| 421 | + |
| 422 | + synchronized (_cacheLock) { |
| 423 | + _cache.remove(key); |
| 424 | + _cache.put(key, imvCacheEntry); |
| 425 | + } |
434 | 426 | } |
| 427 | + |
| 428 | +// evict(); // the limit is checked when the block is added instead |
435 | 429 | } finally { |
436 | 430 | imvCacheEntry.lock.unlock(); |
437 | 431 | } |
438 | 432 |  |
439 | | - evict(); // when we add the block, we shall check for limit. |
440 | 433 | return imvCacheEntry.value; |
441 | 434 | } |
442 | 435 |  |
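The `remove(key)` / `put(key, ...)` pair in `loadFromDisk` is what actually produces the "add last" placement: in an insertion-ordered `LinkedHashMap`, a plain `put()` on an existing key keeps the entry's old position, so a freshly reloaded block would otherwise stay at the eviction end of the map. A small self-contained demonstration of that behavior:

```java
import java.util.LinkedHashMap;

// Demonstrates why remove+put is needed to refresh a key's position
// in an insertion-ordered LinkedHashMap.
public class AddLastDemo {
    public static void main(String[] args) {
        LinkedHashMap<String, Integer> cache = new LinkedHashMap<>();
        cache.put("a", 1);
        cache.put("b", 2);
        cache.put("c", 3);

        cache.put("a", 10); // overwrite only: iteration order stays a, b, c
        System.out.println(cache.keySet()); // [a, b, c]

        cache.remove("a"); // remove+put moves "a" to the tail
        cache.put("a", 10);
        System.out.println(cache.keySet()); // [b, c, a]
    }
}
```

An alternative with the same effect would be constructing the map with the `LinkedHashMap(initialCapacity, loadFactor, accessOrder)` constructor and `accessOrder = true`, which reorders on every `get()`; the explicit remove+put keeps reordering an opt-in, per-call decision.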