Skip to content

Commit b4adb45

Browse files
committed
Fix type consistency in memory cache and correct test assertions
- Store tensor-converted data in memory cache to match disk cache types - Fix test_automatic_hybrid_caching assertions to account for _InplaceXform (from coderabbit)
1 parent 749a518 commit b4adb45

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

monai/data/dataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,14 +435,16 @@ def _cachecheck(self, item_transformed):
435435
if self.in_memory:
436436
self._memory_cache[cache_key] = _item_transformed
437437
return _item_transformed
438+
# Convert to tensor for disk storage (and memory cache consistency)
439+
_item_converted = convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta)
438440
try:
439441
# NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
440442
# to make the cache more robust to manual killing of parent process
441443
# which may leave partially written cache files in an incomplete state
442444
with tempfile.TemporaryDirectory() as tmpdirname:
443445
temp_hash_file = Path(tmpdirname) / hashfile.name
444446
torch.save(
445-
obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
447+
obj=_item_converted,
446448
f=temp_hash_file,
447449
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
448450
pickle_protocol=self.pickle_protocol,
@@ -457,7 +459,7 @@ def _cachecheck(self, item_transformed):
457459
except PermissionError: # project-monai/monai issue #3613
458460
pass
459461
if self.in_memory:
460-
self._memory_cache[cache_key] = _item_transformed
462+
self._memory_cache[cache_key] = _item_converted
461463
return _item_transformed
462464

463465
def _transform(self, index: int):

tests/data/test_persistentdataset.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,13 @@ def test_automatic_hybrid_caching(self):
291291
# Verify: ALL samples now in RAM again (automatic rebuild from disk)
292292
self.assertEqual(ds2.memory_cache_size, 5)
293293

294-
# Verify: Results are correct
294+
# Verify: Results are correct (transformed by _InplaceXform)
295295
for i, result in enumerate(results):
296-
self.assertEqual(result, [list(range(i))])
296+
if i == 0:
297+
expected = [[1]] # empty list -> append 1
298+
else:
299+
expected = [[np.pi] + list(range(1, i))] # data[0] = 0 + np.pi
300+
self.assertEqual(result, expected)
297301

298302
# === Verify RAM cache provides fast repeated access ===
299303
# Accessing same items again should hit RAM cache (same objects)

0 commit comments

Comments
 (0)