Skip to content

Commit 47ebc1c

Browse files
feat: 💡 remove restrictive scaffolding for computing number of feature maps
1 parent 3afbc2a commit 47ebc1c

File tree

2 files changed

+114
-128
lines changed

2 files changed

+114
-128
lines changed

solution.py

Lines changed: 94 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ def __init__(
316316
msg = "Only allowing odd kernel sizes."
317317
raise ValueError(msg)
318318

319+
self.in_channels = in_channels
320+
self.out_channels = out_channels
321+
self.kernel_size = kernel_size
322+
self.padding = padding
323+
319324
# TASK 3.1: Initialize your modules and define layers.
320325
# YOUR CODE HERE
321326

@@ -355,6 +360,11 @@ def __init__(
355360
msg = "Only allowing odd kernel sizes."
356361
raise ValueError(msg)
357362

363+
self.in_channels = in_channels
364+
self.out_channels = out_channels
365+
self.kernel_size = kernel_size
366+
self.padding = padding
367+
358368
# SOLUTION 3.1: Initialize your modules and define layers.
359369
self.conv_pass = torch.nn.Sequential(
360370
torch.nn.Conv2d(
@@ -577,8 +587,7 @@ def forward(self, x):
577587
# <h4>Task 6: U-Net Implementation</h4>
578588
# <p>Now we will implement our U-Net! We have written some of it for you - follow the steps below to fill in the missing parts.</p>
579589
# <ol>
580-
# <li>Write the helper functions <code>compute_fmaps_encoder</code> and <code>compute_fmaps_decoder</code> that compute the number of input and output feature maps at each level of the U-Net.</li>
581-
# <li>Declare a list of encoder (left) and decoder (right) ConvPasses depending on your depth using the helper functions you wrote above. Consider the special case at the bottom of the U-Net carefully!</li>
590+
# <li>Declare a list of encoder (left) and decoder (right) ConvPasses. Carefully consider the input and output feature maps for each ConvPass!</li>
582591
# <li>Declare an Upsample, Downsample, CropAndConcat, and OutputConv block.</li>
583592
# <li>Implement the <code>forward</code> function, applying the modules you declared above in the proper order.</li>
584593
# </ol>
@@ -648,61 +657,39 @@ def __init__(
648657

649658
# left convolutional passes
650659
self.left_convs = torch.nn.ModuleList()
651-
# TASK 6.2A: Initialize list here
652-
660+
# TASK 6.1A: Initialize list here
661+
# Loop through each level of the encoder from top (level=0) to bottom (level=self.depth - 1)
662+
for level in range(self.depth):
663+
# conv =
664+
# Adding conv module to the list
665+
self.left_convs.append(conv)
653666
# right convolutional passes
654667
self.right_convs = torch.nn.ModuleList()
655-
# TASK 6.2B: Initialize list here
656-
657-
# TASK 6.3: Initialize other modules here
658-
659-
def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
660-
"""Compute the number of input and output feature maps for
661-
a conv block at a given level of the UNet encoder (left side).
662-
663-
Args:
664-
level (int): The level of the U-Net which we are computing
665-
the feature maps for. Level 0 is the input level, level 1 is
666-
the first downsampled layer, and level=depth - 1 is the bottom layer.
667-
668-
Output (tuple[int, int]): The number of input and output feature maps
669-
of the encoder convolutional pass in the given level.
670-
"""
671-
# TASK 6.1A: Implement this function
672-
pass
673-
674-
def compute_fmaps_decoder(self, level: int) -> tuple[int, int]:
675-
"""Compute the number of input and output feature maps for a conv block
676-
at a given level of the UNet decoder (right side). Note:
677-
The bottom layer (depth - 1) is considered an "encoder" conv pass,
678-
so this function is only valid up to depth - 2.
679-
680-
Args:
681-
level (int): The level of the U-Net which we are computing
682-
the feature maps for. Level 0 is the input level, level 1 is
683-
the first downsampled layer, and level=depth - 1 is the bottom layer.
684-
685-
Output (tuple[int, int]): The number of input and output feature maps
686-
of the decoder convolutional pass in the given level.
687-
"""
688-
# TASK 6.1B: Implement this function
689-
pass
668+
# TASK 6.1B: Initialize list here
669+
# Loop through each level of the decoder from top (level=0) to one above bottom (level=self.depth - 2)
670+
for level in range(self.depth - 1):
671+
# Initialize conv module
672+
# conv =
673+
# Adding conv module to the list
674+
self.right_convs.append(conv)
675+
676+
# TASK 6.2: Initialize other modules here
690677

691678
def forward(self, x):
692679
# left side
693680
# Hint - you will need the outputs of each convolutional block in the encoder for the skip connection, so you need to hold on to those output tensors
694681
for i in range(self.depth - 1):
695-
# TASK 6.4A: Implement encoder here
682+
# TASK 6.3A: Implement encoder here
696683
...
697684

698685
# bottom
699-
# TASK 6.4B: Implement bottom of U-Net here
686+
# TASK 6.3B: Implement bottom of U-Net here
700687

701688
# right
702689
for i in range(0, self.depth - 1)[::-1]:
703-
# TASK 6.4C: Implement decoder here
690+
# TASK 6.3C: Implement decoder here
704691
...
705-
# TASK 6.4D: Apply the final convolution and return the output
692+
# TASK 6.3D: Apply the final convolution and return the output
706693
return
707694

708695

@@ -769,28 +756,31 @@ def __init__(
769756

770757
# left convolutional passes
771758
self.left_convs = torch.nn.ModuleList()
772-
# SOLUTION 6.2A: Initialize list here
759+
# SOLUTION 6.1A: Initialize list here
760+
# Loop through each level of the encoder from top (level=0) to bottom (level=self.depth - 1)
773761
for level in range(self.depth):
774762
fmaps_in, fmaps_out = self.compute_fmaps_encoder(level)
775-
self.left_convs.append(
776-
ConvBlock(fmaps_in, fmaps_out, self.kernel_size, self.padding)
777-
)
763+
conv = ConvBlock(fmaps_in, fmaps_out, self.kernel_size, self.padding)
764+
# Adding conv module to the list
765+
self.left_convs.append(conv)
778766

779767
# right convolutional passes
780768
self.right_convs = torch.nn.ModuleList()
781-
# SOLUTION 6.2B: Initialize list here
769+
# SOLUTION 6.1B: Initialize list here
770+
# Loop through each level of the decoder from top (level=0) to one above bottom (level=self.depth - 2)
782771
for level in range(self.depth - 1):
783772
fmaps_in, fmaps_out = self.compute_fmaps_decoder(level)
784-
self.right_convs.append(
785-
ConvBlock(
786-
fmaps_in,
787-
fmaps_out,
788-
self.kernel_size,
789-
self.padding,
790-
)
773+
# Initialize conv module
774+
conv = ConvBlock(
775+
fmaps_in,
776+
fmaps_out,
777+
self.kernel_size,
778+
self.padding,
791779
)
780+
# Adding conv module to the list
781+
self.right_convs.append(conv)
792782

793-
# SOLUTION 6.3: Initialize other modules here
783+
# SOLUTION 6.2: Initialize other modules here
794784
self.downsample = Downsample(self.downsample_factor)
795785
self.upsample = torch.nn.Upsample(
796786
scale_factor=self.downsample_factor,
@@ -850,26 +840,26 @@ def forward(self, x):
850840
convolution_outputs = []
851841
layer_input = x
852842
for i in range(self.depth - 1):
853-
# SOLUTION 6.4A: Implement encoder here
843+
# SOLUTION 6.3A: Implement encoder here
854844
conv_out = self.left_convs[i](layer_input)
855845
convolution_outputs.append(conv_out)
856846
downsampled = self.downsample(conv_out)
857847
layer_input = downsampled
858848

859849
# bottom
860-
# SOLUTION 6.4B: Implement bottom of U-Net here
850+
# SOLUTION 6.3B: Implement bottom of U-Net here
861851
conv_out = self.left_convs[-1](layer_input)
862852
layer_input = conv_out
863853

864854
# right
865855
for i in range(0, self.depth - 1)[::-1]:
866-
# SOLUTION 6.4C: Implement decoder here
856+
# SOLUTION 6.3C: Implement decoder here
867857
upsampled = self.upsample(layer_input)
868858
concat = self.crop_and_concat(convolution_outputs[i], upsampled)
869859
conv_output = self.right_convs[i](concat)
870860
layer_input = conv_output
871861

872-
# SOLUTION 6.4D: Apply the final convolution and return the output
862+
# SOLUTION 6.3D: Apply the final convolution and return the output
873863
return self.final_conv(layer_input)
874864

875865

@@ -1317,6 +1307,11 @@ def __init__(
13171307
msg = "Only allowing odd kernel sizes."
13181308
raise ValueError(msg)
13191309

1310+
self.in_channels = in_channels
1311+
self.out_channels = out_channels
1312+
self.kernel_size = kernel_size
1313+
self.padding = padding
1314+
13201315
# TASK 10C: Initialize your modules and define layers.
13211316
# Use the convolution module matching `ndim`.
13221317
# YOUR CODE HERE
@@ -1441,72 +1436,45 @@ def __init__(
14411436
# left convolutional passes
14421437
self.left_convs = torch.nn.ModuleList()
14431438
# TASK 10G: Initialize list here
1444-
# After you implemented the conv pass you can copy this from TASK 6.2A,
1445-
# but make sure to pass the ndim argument
1439+
# Loop through each level of the encoder from top (level=0) to bottom (level=self.depth - 1)
1440+
for level in range(self.depth):
1441+
# conv =
1442+
# Adding conv module to the list
1443+
self.left_convs.append(conv)
14461444

14471445
# right convolutional passes
14481446
self.right_convs = torch.nn.ModuleList()
14491447
# TASK 10H: Initialize list here
1450-
# After you implemented the conv pass you can copy this from TASK 6.2B,
1451-
# but make sure to pass the ndim argument
1448+
# Loop through each level of the decoder from top (level=0) to one above bottom (level=self.depth - 2)
1449+
for level in range(self.depth - 1):
1450+
# Initialize conv module
1451+
# conv =
1452+
# Adding conv module to the list
1453+
self.right_convs.append(conv)
14521454

14531455
# TASK 10I: Initialize other modules here
1454-
# Same here, copy over from TASK 6.3, but make sure to add the ndim argument
1456+
# Same here, copy over from TASK 6.2, but make sure to add the ndim argument
14551457
# as needed.
14561458

1457-
def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
1458-
"""Compute the number of input and output feature maps for
1459-
a conv block at a given level of the UNet encoder (left side).
1460-
1461-
Args:
1462-
level (int): The level of the U-Net which we are computing
1463-
the feature maps for. Level 0 is the input level, level 1 is
1464-
the first downsampled layer, and level=depth - 1 is the bottom layer.
1465-
1466-
Output (tuple[int, int]): The number of input and output feature maps
1467-
of the encoder convolutional pass in the given level.
1468-
"""
1469-
# TASK 10J: Implement this function.
1470-
# You can copy from TASK 6.1A
1471-
pass
1472-
1473-
def compute_fmaps_decoder(self, level: int) -> tuple[int, int]:
1474-
"""Compute the number of input and output feature maps for a conv block
1475-
at a given level of the UNet decoder (right side). Note:
1476-
The bottom layer (depth - 1) is considered an "encoder" conv pass,
1477-
so this function is only valid up to depth - 2.
1478-
1479-
Args:
1480-
level (int): The level of the U-Net which we are computing
1481-
the feature maps for. Level 0 is the input level, level 1 is
1482-
the first downsampled layer, and level=depth - 1 is the bottom layer.
1483-
1484-
Output (tuple[int, int]): The number of input and output feature maps
1485-
of the decoder convolutional pass in the given level.
1486-
"""
1487-
# TASK 10K: Implement this function.
1488-
# You can copy from TASK 6.1B
1489-
pass
1490-
14911459
def forward(self, x):
14921460
# left side
14931461
# Hint - you will need the outputs of each convolutional block in the encoder for the skip connection, so you need to hold on to those output tensors
14941462
for i in range(self.depth - 1):
14951463
# TASK 10L: Implement encoder here
1496-
# Copy from TASK 6.4A
1464+
# Copy from TASK 6.3A
14971465
...
14981466

14991467
# bottom
15001468
# TASK 10M: Implement bottom of U-Net here
1501-
# Copy from TASK 6.4B
1469+
# Copy from TASK 6.3B
15021470

15031471
# right
15041472
for i in range(0, self.depth - 1)[::-1]:
15051473
# TASK 10N: Implement decoder here
1506-
# Copy from TASK 6.4C
1474+
# Copy from TASK 6.3C
15071475
...
15081476
# TASK 10O: Apply the final convolution and return the output
1509-
# Copy from TASK 6.4D
1477+
# Copy from TASK 6.3D
15101478
return
15111479

15121480

@@ -1581,6 +1549,11 @@ def __init__(
15811549
msg = "Only allowing odd kernel sizes."
15821550
raise ValueError(msg)
15831551

1552+
self.in_channels = in_channels
1553+
self.out_channels = out_channels
1554+
self.kernel_size = kernel_size
1555+
self.padding = padding
1556+
15841557
# SOLUTION 10C: Initialize your modules and define layers.
15851558
# Use the convolution module matching `ndim`.
15861559
# YOUR CODE HERE
@@ -1724,30 +1697,33 @@ def __init__(
17241697
# left convolutional passes
17251698
self.left_convs = torch.nn.ModuleList()
17261699
# SOLUTION 10G: Initialize list here
1727-
# After you implemented the conv pass you can copy this from TASK 6.2A,
1700+
# After you implemented the conv pass you can copy this from TASK 6.1A,
17281701
# but make sure to pass the ndim argument
1702+
# Loop through each level of the encoder from top (level=0) to bottom (level=self.depth - 1)
17291703
for level in range(self.depth):
17301704
fmaps_in, fmaps_out = self.compute_fmaps_encoder(level)
1731-
self.left_convs.append(
1732-
ConvBlock(
1733-
fmaps_in, fmaps_out, self.kernel_size, self.padding, ndim=ndim
1734-
)
1705+
conv = ConvBlock(
1706+
fmaps_in, fmaps_out, self.kernel_size, self.padding, ndim=ndim
17351707
)
1708+
# Adding conv module to the list
1709+
self.left_convs.append(conv)
17361710
# right convolutional passes
17371711
self.right_convs = torch.nn.ModuleList()
17381712
# SOLUTION 10H: Initialize list here
1739-
# After you implemented the conv pass you can copy this from TASK 6.2B,
1713+
# After you implemented the conv pass you can copy this from TASK 6.1B,
17401714
# but make sure to pass the ndim argument
1715+
# Loop through each level of the decoder from top (level=0) to one above bottom (level=self.depth - 2)
17411716
for level in range(self.depth - 1):
17421717
fmaps_in, fmaps_out = self.compute_fmaps_decoder(level)
1743-
self.right_convs.append(
1744-
ConvBlock(
1745-
fmaps_in, fmaps_out, self.kernel_size, self.padding, ndim=ndim
1746-
)
1718+
# Initialize conv module
1719+
conv = ConvBlock(
1720+
fmaps_in, fmaps_out, self.kernel_size, self.padding, ndim=ndim
17471721
)
1722+
# Adding conv module to the list
1723+
self.right_convs.append(conv)
17481724

17491725
# SOLUTION 10I: Initialize other modules here
1750-
# Same here, copy over from TASK 6.3, but make sure to add the ndim argument
1726+
# Same here, copy over from TASK 6.2, but make sure to add the ndim argument
17511727
# as needed.
17521728
self.downsample = Downsample(self.downsample_factor, ndim=ndim)
17531729
self.upsample = torch.nn.Upsample(
@@ -1775,7 +1751,6 @@ def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
17751751
of the encoder convolutional pass in the given level.
17761752
"""
17771753
# SOLUTION 10J: Implement this function.
1778-
# You can copy from TASK 6.1A
17791754
if level == 0:
17801755
fmaps_in = self.in_channels
17811756
else:
@@ -1799,7 +1774,6 @@ def compute_fmaps_decoder(self, level: int) -> tuple[int, int]:
17991774
of the decoder convolutional pass in the given level.
18001775
"""
18011776
# SOLUTION 10K: Implement this function.
1802-
# You can copy from TASK 6.1B
18031777
fmaps_out = self.num_fmaps * self.fmap_inc_factor ** (level)
18041778
concat_fmaps = self.compute_fmaps_encoder(level)[
18051779
1
@@ -1815,29 +1789,29 @@ def forward(self, x):
18151789
layer_input = x
18161790
for i in range(self.depth - 1):
18171791
# SOLUTION 10L: Implement encoder here
1818-
# Copy from TASK 6.4A
1792+
# Copy from TASK 6.3A
18191793
conv_out = self.left_convs[i](layer_input)
18201794
convolution_outputs.append(conv_out)
18211795
downsampled = self.downsample(conv_out)
18221796
layer_input = downsampled
18231797

18241798
# bottom
18251799
# SOLUTION 10M: Implement bottom of U-Net here
1826-
# Copy from TASK 6.4B
1800+
# Copy from TASK 6.3B
18271801
conv_out = self.left_convs[-1](layer_input)
18281802
layer_input = conv_out
18291803

18301804
# right
18311805
for i in range(0, self.depth - 1)[::-1]:
18321806
# SOLUTION 10N: Implement decoder here
1833-
# Copy from TASK 6.4C
1807+
# Copy from TASK 6.3C
18341808
upsampled = self.upsample(layer_input)
18351809
concat = self.crop_and_concat(convolution_outputs[i], upsampled)
18361810
conv_output = self.right_convs[i](concat)
18371811
layer_input = conv_output
18381812

18391813
# SOLUTION 10O: Apply the final convolution and return the output
1840-
# Copy from TASK 6.4D
1814+
# Copy from TASK 6.3D
18411815
return self.final_conv(layer_input)
18421816

18431817

0 commit comments

Comments
 (0)