Implementation of ResUNet architecture #70
base: hanaol/model-hydra
Conversation
Exciting! I'm glad this wasn't that painful to implement too!! Great idea!
forklady42
left a comment
Left a few comments, but none are blocking. Feel free to address them and merge.
out = torch.cat([out, skips.pop()], dim=1)
out = dec(out)
out = self.out_conv(out)
out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None]
Slight risk of division by zero here. It's good practice to add a small epsilon to reduce that possibility, e.g. out / (torch.sum(...) + 1e-8).
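A minimal sketch of that guard, assuming the normalization is pulled into a helper (the helper name and epsilon value are illustrative):

```python
import torch

def normalize_voxels(out: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Sum over the spatial dimensions (D, H, W), keeping dims so the
    # division broadcasts, then normalize each sample/channel to sum to 1.
    total = torch.sum(out, dim=(-3, -2, -1), keepdim=True)
    # The small epsilon guards against division by zero when a prediction
    # is (numerically) all zeros.
    return out / (total + eps)
```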
def upsample(cin, cout):
    return nn.Sequential(
        nn.ConvTranspose3d(cin, cout, kernel_size=2, stride=2),
Not a blocker here, but curious how you're thinking about the repeating nature of materials when upsampling. I know circular padding doesn't have a clear interpretation for transposed convolutions, so this seems fine. Another option, potentially worth trying, would be to upsample first and then do a convolution with circular padding.
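A sketch of that alternative, assuming nearest-neighbor upsampling followed by a circularly padded convolution (the function name and kernel size are illustrative):

```python
import torch.nn as nn

def upsample_circular(cin, cout):
    # Upsample first, then convolve with circular padding so the periodic
    # boundary of the material is respected; ConvTranspose3d has no clear
    # circular-padding analogue.
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv3d(cin, cout, kernel_size=3, padding=1, padding_mode="circular"),
    )
```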
else:
    self.skip = nn.Identity()

def forward(self, x):
Could be nice to have gradient checkpointing here
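A minimal sketch of what that could look like, using torch.utils.checkpoint to wrap an existing block (the wrapper name is illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(torch.nn.Module):
    # Recomputes the wrapped block's activations during the backward pass
    # instead of storing them, trading extra compute for lower memory use.
    def __init__(self, block: torch.nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x):
        if self.training:
            # use_reentrant=False is the recommended mode in recent PyTorch.
            return checkpoint(self.block, x, use_reentrant=False)
        return self.block(x)
```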
def forward(self, x):
    h = self.act(self.norm1(self.conv1(x)))
    h = self.norm2(self.conv2(h))
    return self.act(h + self.skip(x))
Just to clarify, you're intentionally forcing this line and the function on line 22 to use the same instance of PReLU and learn the same parameter rather than each being independent and learning separate parameters, right?
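For reference, a sketch of the distinction being asked about (channel counts and norm layers are illustrative, not the PR's actual block):

```python
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm3d(channels)
        self.norm2 = nn.BatchNorm3d(channels)
        # A single nn.PReLU instance: both activation calls in forward()
        # share one learned slope parameter.
        self.act = nn.PReLU()
        # For independent parameters, use two instances instead:
        # self.act1, self.act2 = nn.PReLU(), nn.PReLU()
        self.skip = nn.Identity()

    def forward(self, x):
        h = self.act(self.norm1(self.conv1(x)))
        h = self.norm2(self.conv2(h))
        return self.act(h + self.skip(x))
```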
This PR implements a ResUNet architecture as an alternative to the existing ResNet model. The original ResNet operates in the data dimension space, where an explicit initial downsampling step is required to accelerate training. In the ResNet/QM9 workflow, this was achieved by removing grid cells (pixels), which can lead to a loss of information.
In contrast, the proposed ResUNet architecture includes both encoder and decoder components, allowing it to handle downsampling and upsampling internally and more efficiently. Initial tests on the QM9 dataset indicate that the ResUNet achieves performance and training speed comparable to the ResNet/QM9 model. A more thorough comparison, however, will require further hyperparameter tuning for the ResUNet architecture.
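A minimal, self-contained sketch of the encoder/decoder idea described above (layer widths, depth, and block contents are simplified placeholders; the actual PR uses residual blocks and the output normalization discussed in the review):

```python
import torch
import torch.nn as nn

class TinyResUNet3D(nn.Module):
    # Downsampling and upsampling are handled inside the network via strided
    # convolutions and transposed convolutions, with encoder features
    # concatenated back in at each decoder stage (skip connections).
    def __init__(self, cin=1, cout=1, widths=(16, 32, 64)):
        super().__init__()
        self.in_conv = nn.Conv3d(cin, widths[0], kernel_size=3, padding=1)
        self.downs = nn.ModuleList(
            [nn.Conv3d(widths[i], widths[i + 1], kernel_size=2, stride=2)
             for i in range(len(widths) - 1)]
        )
        self.ups = nn.ModuleList(
            [nn.ConvTranspose3d(widths[i + 1], widths[i], kernel_size=2, stride=2)
             for i in reversed(range(len(widths) - 1))]
        )
        self.decs = nn.ModuleList(
            [nn.Conv3d(2 * widths[i], widths[i], kernel_size=3, padding=1)
             for i in reversed(range(len(widths) - 1))]
        )
        self.out_conv = nn.Conv3d(widths[0], cout, kernel_size=1)

    def forward(self, x):
        out = self.in_conv(x)
        skips = []
        for down in self.downs:
            skips.append(out)
            out = down(out)
        for up, dec in zip(self.ups, self.decs):
            out = up(out)
            out = torch.cat([out, skips.pop()], dim=1)
            out = dec(out)
        return self.out_conv(out)
```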