

@hanaol hanaol commented Jan 27, 2026

This PR implements a ResUNet architecture as an alternative to the existing ResNet model. The original ResNet operates in the data dimension space, where an explicit initial downsampling step is required to accelerate training. In the ResNet/QM9 workflow, this was achieved by removing grid cells (pixels), which can lead to a loss of information.

In contrast, the proposed ResUNet architecture includes both encoder and decoder components, allowing it to handle downsampling and upsampling internally and more efficiently. Initial tests on the QM9 dataset indicate that the ResUNet achieves performance and training speed comparable to those of the ResNet/QM9 model. A more thorough comparison, however, will require further hyperparameter tuning of the ResUNet architecture.
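The encoder/decoder idea described above can be sketched roughly as follows. This is a minimal, hypothetical module for illustration only (layer names, widths, and depth are assumptions, not the PR's actual implementation): the encoder downsamples internally and the decoder upsamples with a skip connection, so no explicit grid-cell-removal preprocessing is needed.

```python
import torch
import torch.nn as nn

class TinyResUNet(nn.Module):
    """Hypothetical one-level U-Net-style sketch: downsampling happens
    inside the network, and a skip connection preserves full-resolution
    information instead of discarding grid cells up front."""

    def __init__(self, cin=1, width=8):
        super().__init__()
        self.enc1 = nn.Conv3d(cin, width, 3, padding=1)
        self.down = nn.Conv3d(width, 2 * width, 2, stride=2)          # internal downsampling
        self.up = nn.ConvTranspose3d(2 * width, width, 2, stride=2)   # internal upsampling
        self.dec1 = nn.Conv3d(2 * width, width, 3, padding=1)         # consumes skip + upsampled
        self.out_conv = nn.Conv3d(width, cin, 1)

    def forward(self, x):
        s = torch.relu(self.enc1(x))                    # skip at full resolution
        h = torch.relu(self.down(s))                    # half resolution
        h = torch.relu(self.up(h))                      # back to full resolution
        h = torch.relu(self.dec1(torch.cat([h, s], dim=1)))
        return self.out_conv(h)

x = torch.randn(2, 1, 8, 8, 8)
y = TinyResUNet()(x)
print(y.shape)  # torch.Size([2, 1, 8, 8, 8])
```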

@Andrew-S-Rosen

Exciting! I'm glad this wasn't that painful to implement too!! Great idea!

@hanaol hanaol requested a review from forklady42 January 30, 2026 17:00
@hanaol hanaol marked this pull request as ready for review January 30, 2026 17:01

@forklady42 forklady42 left a comment


Left a few comments, but none are blocking. Feel free to address them and merge.

out = torch.cat([out, skips.pop()], dim=1)
out = dec(out)
out = self.out_conv(out)
out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None]

Slight risk of division by zero. Good practice to add a small epsilon to reduce that possibility, e.g. out / (torch.sum(....) + 1e-8)
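A sketch of that suggestion, using `keepdim=True` in place of the manual `[..., None, None, None]` broadcasting (the `1e-8` value is the reviewer's example, not a tuned constant):

```python
import torch

def normalize_density(out, eps=1e-8):
    """Normalize over the three spatial dims; eps guards against division
    by zero when a sample's voxels sum to (near) zero. keepdim=True is
    equivalent to the original [..., None, None, None] re-broadcasting."""
    total = torch.sum(out, dim=(-3, -2, -1), keepdim=True)
    return out / (total + eps)

x = torch.zeros(1, 1, 4, 4, 4)   # an all-zero input would otherwise divide by zero
y = normalize_density(x)
print(torch.isfinite(y).all())   # tensor(True)
```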


def upsample(cin, cout):
return nn.Sequential(
nn.ConvTranspose3d(cin, cout, kernel_size=2, stride=2),

Not a blocker here but curious about how you're thinking about the repeating nature of materials when upsampling. I know circular padding doesn't have a clear interpretation for transposed convolutions, so this seems fine. Another option, potentially worth trying would be to upsample first and then do convolution with circular padding.
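A rough sketch of that alternative, upsampling first and then convolving with circular padding so the periodicity of the material is respected (the kernel size and `nearest` mode are illustrative assumptions, not a recommendation from the PR):

```python
import torch
import torch.nn as nn

def periodic_upsample(cin, cout):
    """Upsample, then convolve with circular padding. Unlike
    ConvTranspose3d, the circular padding wraps around the cell
    boundaries, matching periodic boundary conditions."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv3d(cin, cout, kernel_size=3, padding=1, padding_mode="circular"),
    )

x = torch.randn(1, 4, 6, 6, 6)
y = periodic_upsample(4, 2)(x)
print(y.shape)  # torch.Size([1, 2, 12, 12, 12])
```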

else:
self.skip = nn.Identity()

def forward(self, x):

Could be nice to have gradient checkpointing here
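One way that suggestion could look, as a hedged sketch (the block internals below are simplified stand-ins, not the PR's actual layers): wrapping the block's forward pass in `torch.utils.checkpoint` recomputes activations during backward instead of storing them, trading compute for memory.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    """Hypothetical residual block with gradient checkpointing:
    activations inside _forward_impl are recomputed in backward."""

    def __init__(self, channels=4):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, 3, padding=1)
        self.act = nn.PReLU()

    def _forward_impl(self, x):
        h = self.act(self.conv1(x))
        return self.act(self.conv2(h) + x)

    def forward(self, x):
        if self.training and x.requires_grad:
            # Save memory: recompute activations during the backward pass.
            return checkpoint(self._forward_impl, x, use_reentrant=False)
        return self._forward_impl(x)

x = torch.randn(1, 4, 4, 4, 4, requires_grad=True)
block = CheckpointedBlock()
block.train()
block(x).sum().backward()
print(x.grad is not None)  # True
```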

def forward(self, x):
h = self.act(self.norm1(self.conv1(x)))
h = self.norm2(self.conv2(h))
return self.act(h + self.skip(x))

Just to clarify, you're intentionally forcing this line and the function on line 22 to use the same instance of PReLU and learn the same parameter rather than each being independent and learning separate parameters, right?
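For illustration of the distinction being asked about (toy modules, not the PR's code): reusing one `nn.PReLU` instance means every call site shares the same learnable slope, while separate instances each learn their own.

```python
import torch.nn as nn

class SharedAct(nn.Module):
    """Shared: one nn.PReLU instance, one learnable slope parameter,
    no matter how many places self.act is called in forward."""
    def __init__(self):
        super().__init__()
        self.act = nn.PReLU()

class IndependentAct(nn.Module):
    """Independent: two instances, two separately learned slopes."""
    def __init__(self):
        super().__init__()
        self.act1 = nn.PReLU()
        self.act2 = nn.PReLU()

print(sum(p.numel() for p in SharedAct().parameters()))       # 1
print(sum(p.numel() for p in IndependentAct().parameters()))  # 2
```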
