
Conversation

@StijnvWijn
Contributor

Fixes #1305.

Description

Fixes a slowdown when sampling patches that are much smaller than the image in TorchIO >= 0.20.4.
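The idea can be sketched as follows (a minimal, hypothetical illustration; `extract_patch` and its signature are illustrative, not TorchIO's actual API): instead of deep-copying the whole image and then cropping, slice the patch out of the original tensor and clone only the patch.

```python
import torch

def extract_patch(data: torch.Tensor, index_ini, patch_size):
    """Slice a patch out of a (C, W, H, D) tensor and clone only the patch,
    avoiding a deep copy of the full image."""
    i0, j0, k0 = index_ini
    w, h, d = patch_size
    return data[:, i0:i0 + w, j0:j0 + h, k0:k0 + d].clone()

image = torch.arange(4 * 4 * 4, dtype=torch.float32).reshape(1, 4, 4, 4)
patch = extract_patch(image, (1, 1, 1), (2, 2, 2))
```

Because the patch is cloned, mutating it does not touch the original image, and no full-size copy is ever made.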

Checklist

  • I have read the CONTRIBUTING docs and have a developer setup ready
  • Changes are
    • Non-breaking (would not break existing functionality)
    • Breaking (would cause existing functionality to change)
  • Tests added or modified to cover the changes
  • In-line docstrings updated
  • Documentation updated
  • This pull request is ready to be reviewed

@fepegar fepegar requested a review from nicoloesch June 4, 2025 21:32
@fepegar
Member

fepegar commented Jun 4, 2025

@nicoloesch could you please comment on the general design? I can then fix formatting etc.

Contributor

@nicoloesch nicoloesch left a comment


Really like the changes @StijnvWijn and this is exactly how I would have done it. Some changes are purely stylistic, others require some feedback from @fepegar about formatting and docstring description to make it as clear as possible.
Most of my comments are suggestions instead of requirements.

@StijnvWijn
Contributor Author

StijnvWijn commented Jun 5, 2025

So I was building the docs and I encountered an issue I am not sure how to solve:
The Colin27 dataset is a subclass of Subject, but it does not allow the same init parameters as the Subject class and has some private attributes that we can't access with .items(). I tried copying using the __dict__ attribute as suggested here, but I can't seem to initialize a Subject without arguments, and I need to initialize the Colin27 without arguments, so I am unsure how to generalize the copy behaviour so that every class gets handled correctly.

I fixed it, but now the apply_transform function has some less beautiful code to make a new subject and copy it with all its attributes: it uses the dict and class attributes. I added some comments, so hopefully it is clear why.
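For reference, a generic way to copy a Subject-like dict subclass without calling its __init__ might look like the sketch below (hedged illustration, not the code in this PR; `copy_like` and `NoArgSubclass` are made-up names). The trick is to allocate the instance with `__new__` and then deep-copy both the dict items and the instance attributes:

```python
import copy

class Base(dict):
    pass

class NoArgSubclass(Base):
    """Stands in for a subclass like Colin27 that takes no init arguments."""
    def __init__(self):
        super().__init__(t1='image')
        self._private = 'metadata'  # private attribute not visible via .items()

def copy_like(obj):
    # Allocate an instance of the same class without invoking __init__,
    # so subclasses with incompatible init signatures still work.
    new = type(obj).__new__(type(obj))
    new.update(copy.deepcopy(dict(obj)))            # copy dict items
    new.__dict__.update(copy.deepcopy(obj.__dict__))  # copy instance attrs
    return new

original = NoArgSubclass()
clone = copy_like(original)
```

This preserves both the mapping contents and private attributes regardless of the subclass's constructor signature.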

@StijnvWijn
Contributor Author

@nicoloesch Do you have any comments on the new copy mechanism?

Contributor

@nicoloesch nicoloesch left a comment


All my comments have been addressed! Thank you @StijnvWijn

@fepegar fepegar changed the title 1305 Patch sampler speedup Stop copying whole image before cropping Jun 15, 2025
@fepegar
Member

fepegar commented Jun 15, 2025

@allcontributors please add @nicoloesch for design, maintenance, question, review

@allcontributors
Contributor

@fepegar

I've put up a pull request to add @nicoloesch! 🎉

@fepegar fepegar merged commit 5f28b56 into TorchIO-project:main Jun 15, 2025
21 of 22 checks passed
@fepegar
Member

fepegar commented Jun 15, 2025

@allcontributors please add @StijnvWijn for code

@allcontributors
Contributor

@fepegar

I've put up a pull request to add @StijnvWijn! 🎉

@fepegar
Member

fepegar commented Jun 15, 2025

Thank you both for your contribution!

@fepegar
Member

fepegar commented Jul 3, 2025

Unfortunately, these changes are giving memory trouble. The script below shows that, from v0.20.10, references to the uncropped image are still held somewhere, and memory just grows and grows.

import gc
import os
import psutil


num_subjects = 1


process = psutil.Process(os.getpid())


def print_referrers_recursively(obj, depth=2):
    """Recursively print referrers of an object."""
    if depth == 0:
        return
    referrers = gc.get_referrers(obj)
    print(f"Object {id(obj)} has {len(referrers)} referrers")
    for ref in referrers:
        print(f"Referrer type and ID: {type(ref)}, {id(ref)}")
        if isinstance(ref, dict):
            print(f"  Dict keys: {list(ref.keys())[:10]}")
        elif isinstance(ref, (list, tuple)):
            print(f"  Sequence length: {len(ref)}")
        elif isinstance(ref, type):
            print(f"  Class: {ref.__name__}")
        else:
            print(f"  Other referrer: {type(ref)} (repr suppressed)")
        print()
        # Recursively check referrers of the referrer
        print_referrers_recursively(ref, depth=depth - 1)
        print()


def print_memory_usage():
    """Print the current memory usage of the process."""
    mem = process.memory_info().rss / 1024**2  # Convert bytes to MB
    print(f"\nMemory usage: {mem:.2f} MB")

    import torch

    for obj in gc.get_objects():
        is_tensor = torch.is_tensor(obj)
        if is_tensor:
            if obj.shape[-1] == 16:
                continue
            print("-" * 80)
            tensor = obj if is_tensor else obj.data
            tensor_id = id(tensor)
            print(
                f"Type: {type(tensor)}, "
                f"Shape: {tuple(tensor.shape)}, "
                f"Dtype: {tensor.dtype}, "
                f"ID: {tensor_id}"
            )
            print_referrers_recursively(obj)
            print("-" * 80)
            import sys

            sys.exit(0)


print("\nChecking memory usage before importing...")
print_memory_usage()

print("\nImporting...")
import torchio as tio

print_memory_usage()

print("\nCreating subjects list...")
subjects = [tio.datasets.Colin27(2008) for _ in range(num_subjects)]
transform = tio.CropOrPad(16)
print_memory_usage()

print("\nInstantiating SubjectsDataset...")
dataset = tio.SubjectsDataset(subjects, transform=transform)
print_memory_usage()

print("\nLoading...")
for subject in dataset:
    print_memory_usage()

@fepegar
Member

fepegar commented Jul 3, 2025

I'm going to submit a hotfix reverting this in a new version, but it's a very nice feature so I hope it can be fixed :)

@StijnvWijn
Contributor Author

Hmm that is very sad. I am not sure where the issue would occur, because all attributes seem to be deepcopied from the original input. Do you have any idea where to start?

@fepegar
Member

fepegar commented Jul 12, 2025

Hmm that is very sad.

I know, and I'm sorry I had to revert this. I'm sure we'll be able to find a solution! I'm happy to refactor the design if needed.

I am not sure where the issue would occur, because all attributes seem to be deepcopied from the original input. Do you have any idea where to start?

I would start with the snippet I shared above.

@StijnvWijn
Contributor Author

Thanks for your quick response!

I am not sure I can reproduce your issue. In your snippet, the memory usage does increase slightly between instantiating the dataset and creating the cropped patches, but IMO this is expected: we make a copy of the subject with all the old attributes and put a copy of the cropped area of the original image into it.

Also, if I run the subject generation loop multiple times, I see minor fluctuations in memory usage, but they go both up and down between loops, which is not the behaviour I would expect with a memory leak.

I am running it on a server with quite a bit of RAM, just like I described in my original issue, so that might also have an impact on my results.

@StijnvWijn
Contributor Author

@fepegar

I wouldn't mind working on this, but currently I can't test whether my implementation works. Shall I open a new issue or PR to discuss this, or continue here?

@nicoloesch
Contributor

I may also be able to assist, but I currently have very limited time as I am in the final stages of submitting my PhD. I may have some time this week; I will start with the snippet you provided, @fepegar, and check whether I can reproduce the memory leak!

@StijnvWijn
Contributor Author

Ah awesome, that is great news for you. Good luck!

It is not too urgent as we are currently just using an older version of tio, so you don't need to worry too much about it :)

@fepegar
Member

fepegar commented Aug 1, 2025

Thanks for your quick response!

I am not sure I can reproduce your issue. In your snippet, indeed the memory usage increases slightly between the instantiation of the dataset and the creation of the cropped patches, but IMO this is expected, because we make a copy of the subject with all the old attributes and put a copy of the cropped area of the original image into the subject.

Also if I run the subject generation loop multiple times, I see some minor fluctuations in memory usage, but they go both up and down between loops, so not the behaviour I would expect with a memory leak.

I am running it on a server with quite a bit of RAM, just like I described in my original issue, so that might also have an impact on my results.

Which TorchIO version and OS are you using?

@StijnvWijn
Contributor Author

Thanks for your response!

So my environment has the following:

Platform:   Linux-5.15.0-143-generic-x86_64-with-glibc2.35
TorchIO:    0.20.3
PyTorch:    2.2.0+cpu
SimpleITK:  2.5.0 (ITK 5.4)
NumPy:      1.26.4
Python:     3.11.0rc1 (main, Aug 12 2022, 10:02:14) [GCC 11.2.0]

And if I run the following example (adapted from your script):

import gc
import os
import psutil


num_subjects = 1


process = psutil.Process(os.getpid())


def print_referrers_recursively(obj, depth=2):
    """Recursively print referrers of an object."""
    if depth == 0:
        return
    referrers = gc.get_referrers(obj)
    print(f"Object {id(obj)}, {type(obj)} has {len(referrers)} referrers")
    for ref in referrers:
        print(f"Referrer type and ID: {type(ref)}, {id(ref)}")
        if isinstance(ref, dict):
            print(f"  Dict keys: {list(ref.keys())[:10]}")
        elif isinstance(ref, (list, tuple)):
            print(f"  Sequence length: {len(ref)}")
        elif isinstance(ref, type):
            print(f"  Class: {ref.__name__}")
        else:
            print(f"  Other referrer: {type(ref)} (repr suppressed)")
        print()
        # Recursively check referrers of the referrer
        print_referrers_recursively(ref, depth=depth - 1)
        print()


def print_memory_usage():
    """Print the current memory usage of the process."""
    mem = process.memory_info().rss / 1024**2  # Convert bytes to MB
    print(f"\nMemory usage: {mem:.2f} MB")

    import torch

    for obj in gc.get_objects():
        is_tensor = torch.is_tensor(obj)
        if is_tensor:
            if obj.shape[-1] == 16:
                continue
            print("-" * 80)
            tensor = obj if is_tensor else obj.data
            tensor_id = id(tensor)
            print(
                f"Type: {type(tensor)}, "
                f"Shape: {tuple(tensor.shape)}, "
                f"Dtype: {tensor.dtype}, "
                f"ID: {tensor_id}"
            )
            print_referrers_recursively(obj)
            print("-" * 80)
            # import sys

            # sys.exit(0)


print("\nChecking memory usage before importing...")
print_memory_usage()

print("\nImporting...")
import torchio as tio

print_memory_usage()

print("\nCreating subjects list...")
subjects = [tio.datasets.Colin27(2008) for _ in range(num_subjects)]
transform = tio.CropOrPad(16)
print_memory_usage()

print("\nInstantiating SubjectsDataset...")
dataset = tio.SubjectsDataset(subjects, transform=transform)
print_memory_usage()

print("\nLoading...")
for _ in range(5):
    print("testing memory usage")
    for subject in dataset:
        print_memory_usage()

for subject in dataset:
    print_memory_usage()
    print(f"Subject ID: {id(subject)}")
    print(f"Subject keys: {list(subject.keys())}")
    for key, value in subject.items():
        print(f"  {key}: {type(value)}, Shape: {value.shape if hasattr(value, 'shape') else 'N/A'}")
    print_referrers_recursively(subject['t1'], depth=3)

With the following tio versions, I get:
Torchio 0.20.3 (No memory leak expected):

Checking memory usage before importing...

Memory usage: 12.03 MB

Importing...

Memory usage: 325.64 MB

Creating subjects list...

Memory usage: 326.93 MB

Instantiating SubjectsDataset...

Memory usage: 326.93 MB

Loading...
testing memory usage

Memory usage: 332.82 MB
testing memory usage

Memory usage: 334.54 MB
testing memory usage

Memory usage: 334.16 MB
testing memory usage

Memory usage: 335.05 MB
testing memory usage

Memory usage: 335.17 MB

Memory usage: 334.27 MB
Subject ID: 139742629258896
Subject keys: ['t1', 't2', 'pd', 'cls']
  t1: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  t2: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  pd: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  cls: <class 'torchio.data.image.LabelMap'>, Shape: (1, 16, 16, 16)
Object 139742629251600, <class 'torchio.data.image.ScalarImage'> has 2 referrers
Referrer type and ID: <class 'torchio.data.subject.Subject'>, 139742629258896
  Dict keys: ['t1', 't2', 'pd', 'cls']

Object 139742629258896, <class 'torchio.data.subject.Subject'> has 2 referrers
Referrer type and ID: <class 'list'>, 139742629315328
  Sequence length: 2

Object 139742629315328, <class 'list'> has 2 referrers
Referrer type and ID: <class 'list_iterator'>, 139742629209584
  Other referrer: <class 'list_iterator'> (repr suppressed)


Referrer type and ID: <class 'list'>, 139742629318080
  Sequence length: 2



Referrer type and ID: <class 'dict'>, 139746139912960
  Dict keys: ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'gc']

Object 139746139912960, <class 'dict'> has 5 referrers
Referrer type and ID: <class 'list'>, 139742629318080
  Sequence length: 2


Referrer type and ID: <class 'module'>, 139746139916304
  Other referrer: <class 'module'> (repr suppressed)


Referrer type and ID: <class 'function'>, 139746139775488
  Other referrer: <class 'function'> (repr suppressed)


Referrer type and ID: <class 'function'>, 139746137540544
  Other referrer: <class 'function'> (repr suppressed)


Referrer type and ID: <class 'function'>, 139746138048768
  Other referrer: <class 'function'> (repr suppressed)

Referrer type and ID: <class 'dict'>, 139742637074816
  Dict keys: ['t1', 't2', 'pd', 'cls', 'applied_transforms']

Object 139742637074816, <class 'dict'> has 2 referrers
Referrer type and ID: <class 'torchio.data.subject.Subject'>, 139742629258896
  Dict keys: ['t1', 't2', 'pd', 'cls']

Object 139742629258896, <class 'torchio.data.subject.Subject'> has 3 referrers
Referrer type and ID: <class 'list'>, 139742629315328
  Sequence length: 2


Referrer type and ID: <class 'list'>, 139744021918784
  Sequence length: 2


Referrer type and ID: <class 'dict'>, 139746139912960
  Dict keys: ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'gc']

Referrer type and ID: <class 'list'>, 139742629315328
  Sequence length: 2

Object 139742629315328, <class 'list'> has 2 referrers
Referrer type and ID: <class 'list_iterator'>, 139742629209584
  Other referrer: <class 'list_iterator'> (repr suppressed)


Referrer type and ID: <class 'list'>, 139744021918784
  Sequence length: 2

0.20.10 (Version with this MR)

Checking memory usage before importing...

Memory usage: 12.03 MB
/home/stijn/venvs/rndeep_dev3.11/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py:359: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn(

Importing...

Memory usage: 325.46 MB
/home/stijn/venvs/rndeep_dev3.11/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py:359: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn(

Creating subjects list...

Memory usage: 326.75 MB

Instantiating SubjectsDataset...

Memory usage: 326.75 MB

Loading...
testing memory usage

Memory usage: 334.41 MB
testing memory usage

Memory usage: 334.92 MB
testing memory usage

Memory usage: 335.52 MB
testing memory usage

Memory usage: 335.66 MB
testing memory usage

Memory usage: 336.83 MB

Memory usage: 337.18 MB
Subject ID: 140106867906544
Subject keys: ['t1', 't2', 'pd', 'cls']
  t1: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  t2: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  pd: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  cls: <class 'torchio.data.image.LabelMap'>, Shape: (1, 16, 16, 16)
Object 140106867905392, <class 'torchio.data.image.ScalarImage'> has 2 referrers
Referrer type and ID: <class 'torchio.datasets.mni.colin.Colin27'>, 140106867906544
  Dict keys: ['t1', 't2', 'pd', 'cls']

Object 140106867906544, <class 'torchio.datasets.mni.colin.Colin27'> has 2 referrers
Referrer type and ID: <class 'list'>, 140106867953856
  Sequence length: 2

Object 140106867953856, <class 'list'> has 2 referrers
Referrer type and ID: <class 'list_iterator'>, 140106867762048
  Other referrer: <class 'list_iterator'> (repr suppressed)


Referrer type and ID: <class 'list'>, 140106867953280
  Sequence length: 2

Referrer type and ID: <class 'dict'>, 140110378056448
  Dict keys: ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'gc']

Object 140110378056448, <class 'dict'> has 5 referrers
Referrer type and ID: <class 'list'>, 140106867953280
  Sequence length: 2


Referrer type and ID: <class 'module'>, 140110378059792
  Other referrer: <class 'module'> (repr suppressed)


Referrer type and ID: <class 'function'>, 140110377918976
  Other referrer: <class 'function'> (repr suppressed)


Referrer type and ID: <class 'function'>, 140110375700416
  Other referrer: <class 'function'> (repr suppressed)


Referrer type and ID: <class 'function'>, 140110376208640
  Other referrer: <class 'function'> (repr suppressed)

Referrer type and ID: <class 'dict'>, 140108265044224
  Dict keys: ['version', 'name', 'url_dir', 'filename', 'url', 'applied_transforms', 't1', 't2', 'pd', 'cls']

Object 140108265044224, <class 'dict'> has 2 referrers
Referrer type and ID: <class 'torchio.datasets.mni.colin.Colin27'>, 140106867906544
  Dict keys: ['t1', 't2', 'pd', 'cls']

Object 140106867906544, <class 'torchio.datasets.mni.colin.Colin27'> has 3 referrers
Referrer type and ID: <class 'list'>, 140106867953856
  Sequence length: 2


Referrer type and ID: <class 'list'>, 140106875282112
  Sequence length: 2


Referrer type and ID: <class 'dict'>, 140110378056448
  Dict keys: ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'gc']

Referrer type and ID: <class 'list'>, 140106867953856
  Sequence length: 2

Object 140106867953856, <class 'list'> has 2 referrers
Referrer type and ID: <class 'list_iterator'>, 140106867762048
  Other referrer: <class 'list_iterator'> (repr suppressed)


Referrer type and ID: <class 'list'>, 140106875282112
  Sequence length: 2

0.20.17 (Version with hotfix):

Checking memory usage before importing...

Memory usage: 12.12 MB

Importing...

Memory usage: 325.56 MB

Creating subjects list...

Memory usage: 326.85 MB

Instantiating SubjectsDataset...

Memory usage: 326.85 MB

Loading...
testing memory usage

Memory usage: 334.45 MB
testing memory usage

Memory usage: 334.60 MB
testing memory usage

Memory usage: 335.69 MB
testing memory usage

Memory usage: 335.34 MB
testing memory usage

Memory usage: 335.01 MB

Memory usage: 335.30 MB
Subject ID: 140460681742384
Subject keys: ['t1', 't2', 'pd', 'cls']
  t1: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  t2: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  pd: <class 'torchio.data.image.ScalarImage'>, Shape: (1, 16, 16, 16)
  cls: <class 'torchio.data.image.LabelMap'>, Shape: (1, 16, 16, 16)
Object 140460681742672, <class 'torchio.data.image.ScalarImage'> has 2 referrers
Referrer type and ID: <class 'torchio.datasets.mni.colin.Colin27'>, 140460681742384
  Dict keys: ['t1', 't2', 'pd', 'cls']

Object 140460681742384, <class 'torchio.datasets.mni.colin.Colin27'> has 2 referrers
Referrer type and ID: <class 'list'>, 140462379464896
  Sequence length: 2

Object 140462379464896, <class 'list'> has 2 referrers
Referrer type and ID: <class 'list_iterator'>, 140460681555888
  Other referrer: <class 'list_iterator'> (repr suppressed)


Referrer type and ID: <class 'list'>, 140460703722048
  Sequence length: 2

Referrer type and ID: <class 'dict'>, 140464191896384
  Dict keys: ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'gc']

Object 140464191896384, <class 'dict'> has 5 referrers
Referrer type and ID: <class 'list'>, 140460703722048
  Sequence length: 2


Referrer type and ID: <class 'module'>, 140464191899664
  Other referrer: <class 'module'> (repr suppressed)


Referrer type and ID: <class 'function'>, 140464191758848
  Other referrer: <class 'function'> (repr suppressed)


Referrer type and ID: <class 'function'>, 140464189540288
  Other referrer: <class 'function'> (repr suppressed)


Referrer type and ID: <class 'function'>, 140464190048512
  Other referrer: <class 'function'> (repr suppressed)

Referrer type and ID: <class 'dict'>, 140460703776704
  Dict keys: ['version', 'name', 'url_dir', 'filename', 'url', 't1', 't2', 'pd', 'cls', 'applied_transforms']

Object 140460703776704, <class 'dict'> has 2 referrers
Referrer type and ID: <class 'torchio.datasets.mni.colin.Colin27'>, 140460681742384
  Dict keys: ['t1', 't2', 'pd', 'cls']

Object 140460681742384, <class 'torchio.datasets.mni.colin.Colin27'> has 3 referrers
Referrer type and ID: <class 'list'>, 140462379464896
  Sequence length: 2


Referrer type and ID: <class 'list'>, 140460703722048
  Sequence length: 2


Referrer type and ID: <class 'dict'>, 140464191896384
  Dict keys: ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'gc']


Referrer type and ID: <class 'list'>, 140462379464896
  Sequence length: 2

Object 140462379464896, <class 'list'> has 2 referrers
Referrer type and ID: <class 'list_iterator'>, 140460681555888
  Other referrer: <class 'list_iterator'> (repr suppressed)


Referrer type and ID: <class 'list'>, 140460703722048
  Sequence length: 2

This time it does look like there is an increase in memory consumption. I'm not sure what changed, but this means I can investigate!

@fepegar
Member

fepegar commented Aug 4, 2025 via email

@fepegar
Member

fepegar commented Aug 4, 2025

Sorry, my script was not very helpful in the state I shared it. I've checked out v0.20.16 and run the script I attach below and got this output:

Checking memory usage before importing...

Memory usage: 19.73 MB

Importing...

Memory usage: 528.07 MB

Creating subjects list...

Memory usage: 528.70 MB

Instantiating SubjectsDataset...

Memory usage: 528.70 MB

Loading...

Memory usage: 539.53 MB

Memory usage: 1949.85 MB

Memory usage: 3360.27 MB

Memory usage: 4771.17 MB

Memory usage: 6182.50 MB

Memory usage: 7593.51 MB

Memory usage: 1953.14 MB

Memory usage: 3363.36 MB

Memory usage: 4773.75 MB

[...]

Change print_object_info and print_referrers to add verbosity.

import gc
import os

import psutil

num_subjects = 10
print_object_info = False
print_referrers = False

process = psutil.Process(os.getpid())


def print_referrers_recursively(obj, depth=2):
    """Recursively print referrers of an object."""
    if depth == 0:
        return
    referrers = gc.get_referrers(obj)
    print(f'Object {id(obj)} has {len(referrers)} referrers')
    for ref in referrers:
        print(f'Referrer type and ID: {type(ref)}, {id(ref)}')
        if isinstance(ref, dict):
            print(f'  Dict keys: {list(ref.keys())[:10]}')
        elif isinstance(ref, (list, tuple)):
            print(f'  Sequence length: {len(ref)}')
        elif isinstance(ref, type):
            print(f'  Class: {ref.__name__}')
        else:
            print(f'  Other referrer: {type(ref)} (repr suppressed)')
        print()
        # Recursively check referrers of the referrer
        print_referrers_recursively(ref, depth=depth - 1)
        print()


def print_memory_usage():
    """Print the current memory usage of the process."""
    mem = process.memory_info().rss / 1024**2  # Convert bytes to MB
    print(f'\nMemory usage: {mem:.2f} MB')

    if not print_object_info or not print_referrers:
        return

    import torch

    for obj in gc.get_objects():
        is_tensor = torch.is_tensor(obj)
        if is_tensor:
            if obj.shape[-1] == 16:
                continue
            print('-' * 80)
            tensor = obj if is_tensor else obj.data
            tensor_id = id(tensor)
            if print_object_info:
                print(
                    f'Type: {type(tensor)}, '
                    f'Shape: {tuple(tensor.shape)}, '
                    f'Dtype: {tensor.dtype}, '
                    f'ID: {tensor_id}'
                )
            if print_referrers:
                print_referrers_recursively(obj)
            print('-' * 80)


print('\nChecking memory usage before importing...')
print_memory_usage()

print('\nImporting...')
import torchio as tio

print_memory_usage()

print('\nCreating subjects list...')
subjects = [tio.datasets.Colin27(2008) for _ in range(num_subjects)]
transform = tio.CropOrPad(16)
print_memory_usage()

print('\nInstantiating SubjectsDataset...')
dataset = tio.SubjectsDataset(subjects, transform=transform)
print_memory_usage()

print('\nLoading...')
for subject in dataset:
    print_memory_usage()

@StijnvWijn
Contributor Author

StijnvWijn commented Aug 4, 2025

Thanks for your quick response!

I'm sorry, apparently it does not show up on the server, but if I run it on my laptop, I can replicate your issue. I see that the PyTorch versions are different, so maybe it's related to that? There were some changes related to the default collation that made us stay with an older PyTorch version, but I will check.
Laptop:

Platform:   Linux-6.14.0-24-generic-x86_64-with-glibc2.41
TorchIO:    0.20.16
PyTorch:    2.7.0+cu126
SimpleITK:  2.5.0 (ITK 5.4)
NumPy:      2.3.1
Python:     3.11.12 (main, Apr  9 2025, 15:24:58) [GCC 14.2.0]

torchio 0.20.8:

Laptop:
  Before importing: 12.73 MB
  After importing: 549.42 MB
  After creating subjects list: 549.93 MB
  After instantiating SubjectsDataset: 549.93 MB
  Loading: 559.93, 560.28, 560.19, 560.70, 559.93, 560.55, 560.43, 560.38, 559.98, 560.30 MB

Server (torchio version 0.20.8):
  Before importing: 12.08 MB
  After importing: 325.84 MB
  After creating subjects list: 325.84 MB
  After instantiating SubjectsDataset: 325.84 MB
  Loading: 335.29, 335.18, 335.02, 334.73, 334.42, 335.80, 334.57, 336.84, 335.12, 336.25 MB

Torchio 0.20.10:

Laptop (Torchio version: 0.20.10):
  Before importing: 12.61 MB
  After importing: 550.13 MB
  After creating subjects list: 550.64 MB
  After instantiating SubjectsDataset: 550.64 MB
  Loading: 1970.80, 3381.06, 3380.93, 4791.33, 4791.88, 6201.87, 7611.93, 4791.07, 6201.60, 4791.62 MB

Server (torchio version 0.20.10):
  Before importing: 12.02 MB
  After importing: 325.61 MB
  After creating subjects list: 325.61 MB
  After instantiating SubjectsDataset: 325.61 MB
  Loading: 334.90, 334.96, 335.95, 336.54, 336.51, 336.25, 335.81, 336.03, 336.52, 336.50 MB

torchio 0.20.16:

Laptop (torchio version 0.20.16):
  Before importing: 12.90 MB
  After importing: 548.96 MB
  After creating subjects list: 549.48 MB
  After instantiating SubjectsDataset: 549.48 MB
  Loading: 1969.89, 1969.67, 3380.48, 1970.29, 3380.22, 4790.86, 3380.21, 4790.28, 3380.75, 4790.40 MB

Server (torchio version 0.20.16):
  Before importing: 12.02 MB
  After importing: 325.77 MB
  After creating subjects list: 325.77 MB
  After instantiating SubjectsDataset: 325.77 MB
  Loading: 335.65, 334.47, 335.05, 335.25, 335.32, 335.53, 336.39, 336.33, 336.06, 335.98 MB

Server (torchio version 0.20.16, PyTorch 2.7.0):
  Before importing: 12.06 MB
  After importing: 342.15 MB
  After creating subjects list: 342.67 MB
  After instantiating SubjectsDataset: 342.67 MB
  Loading: 1763.43, 1762.62, 3173.18, 1763.83, 3173.73, 1764.70, 3174.76, 3174.92, 4585.34, 4585.33 MB

@fepegar
Member

fepegar commented Aug 4, 2025

Interesting that it's different! We should definitely check the PyTorch versions. However, I'm not sure this is about the collating function, because note that the script doesn't use a DataLoader!

@StijnvWijn
Contributor Author

StijnvWijn commented Aug 4, 2025

Yeah, it definitely seems related to the PyTorch version: if I install torch==2.7.0 on the server, it also starts consuming memory quite quickly, but it does not seem to be a monotonic increase, which is kind of strange. I have to go do some other things now, but I'll investigate this some more later. It seems to show up starting from torch==2.3.0.

torch==2.7.0: (memory usage plot)

torch==2.2.0: (memory usage plot)

@StijnvWijn
Contributor Author

StijnvWijn commented Aug 7, 2025

I have found an issue that might be related: if I create a new tio.Image inside a loop, it seems not to be garbage collected. I'm not sure why yet, but it also causes a linear increase in memory consumption if I run the following snippet:

import torch
import torchio as tio

for _ in range(10):
    new_tensor = torch.ones((1, 100, 100, 100), dtype=torch.float32)
    new_image = tio.Image(tensor=new_tensor)
    print(f'Created tensor with shape {new_tensor.shape} and id {id(new_tensor)}')
    print_memory_usage()  # defined in the scripts above

It does not happen if new_image is just a tensor, an ordinary dict like {'tensor': new_tensor}, or a much simpler class like the one below, so I am not yet sure where the issue lies...

class TestImage(dict):
    def __init__(self, tensor, **kwargs):
        self['data'] = tensor
        self['affine'] = None  # Placeholder for affine, if needed
        super().__init__(**kwargs)

So I asked Gemini and it found the issue: it turns out to be related to the self._check_data_loader() call in the init function, and if I comment it out, it indeed solves the issue, also with the cropping transform. I will copy-paste its explanation. Not sure if we can fix it without breaking the functionality; what do you think @fepegar @nicoloesch?

The Cause of the Leak: A "Frame Leak"

Here is the step-by-step mechanism of the leak:

Object Creation: When your loop calls new_image = tio.Image(tensor=new_tensor), the Image.__init__ method begins execution. At this moment, the Python interpreter creates an "execution frame" for this method call. This frame holds all local variables, including self, which points to the new Image object being created.

The Problematic Call: Inside __init__, the line self._check_data_loader() is executed.

Inspecting the Stack: The _check_data_loader() method calls in_torch_loader(). To figure out if it's "in a torch loader," this utility almost certainly uses Python's inspect module to look at the current call stack (e.g., by calling inspect.stack()).

Capturing the Frame: When inspect.stack() is called, it gets a reference to the execution frame of Image.__init__. This frame, as we know, contains a reference to the Image object via the self variable.

Caching the Warning: The _check_data_loader() method then calls warnings.warn(). Python's warnings module is designed to be helpful and can cache information about where and why a warning was triggered. This cached information often includes the stack frames.

This creates a persistent reference cycle that the garbage collector cannot break:

The Image object is referenced by...

...the self variable inside its own __init__ frame, which is referenced by...

...an inspect.FrameInfo object, which is now held by...

...the internal cache of the warnings module.

Because the warnings module holds onto this reference, the Image object and its large data tensor can never be freed, causing the memory usage to climb with every new image created in your loop.
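The retention pattern described above can be reproduced in isolation (a simplified, hypothetical demonstration, not TorchIO's actual code; `FakeImage` and `_cache` are stand-ins for `tio.Image` and the warnings module's internal state): any externally held reference to a stack frame captured inside `__init__` keeps `self` alive.

```python
import gc
import inspect
import weakref

_cache = []  # stands in for an external cache (e.g. warnings' internal state)

class FakeImage:
    def __init__(self):
        # A utility inspecting the call stack from inside __init__ captures
        # this frame; the frame's locals include `self`.
        frames = inspect.stack()
        _cache.append(frames[0])  # something external retains the FrameInfo

obj = FakeImage()
ref = weakref.ref(obj)
del obj
gc.collect()
# The object is kept alive via cache -> FrameInfo -> frame -> locals -> self.
leaked = ref() is not None

_cache.clear()
gc.collect()
# Dropping the cached frame lets the garbage collector free the object.
freed = ref() is None
```

This also suggests why commenting out `_check_data_loader()` fixes the leak: once nothing retains the captured frame, the image and its tensor become collectable again.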

@fepegar
Member

fepegar commented Aug 8, 2025

Wow! Great investigation! That cheeky warnings module cache.
I've cherry-picked the commit of this PR and it looks good <3
I'm sorry I blamed the memory issues on this one instead of the data loader thing!

@StijnvWijn
Contributor Author

Awesome!

No problem, it also took me quite some time to find the issue. I didn't even know that was how the warnings module worked, haha.



Development

Successfully merging this pull request may close these issues.

PatchSampler slowdown after upgrading
