@@ -124,7 +124,7 @@ def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3):
124124 # s0 is upsampled 2x from s1, representing feature sizes at two resolutions.
125125 # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization)
126126 # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims)
127- self .ram_cost = in_channel / out_channel * 2 ** self ._spatial_dims + 3
127+ self .ram_cost = in_channel / out_channel * 2 ** self ._spatial_dims + 3
128128
129129
130130class MixedOp (nn .Module ):
@@ -330,7 +330,7 @@ def __init__(
330330 # define downsample stems before DiNTS search
331331 if use_downsample :
332332 self .stem_down [str (res_idx )] = StemTS (
333- nn .Upsample (scale_factor = 1 / (2 ** res_idx ), mode = mode , align_corners = True ),
333+ nn .Upsample (scale_factor = 1 / (2 ** res_idx ), mode = mode , align_corners = True ),
334334 conv_type (
335335 in_channels = in_channels ,
336336 out_channels = self .filter_nums [res_idx ],
@@ -373,7 +373,7 @@ def __init__(
373373
374374 else :
375375 self .stem_down [str (res_idx )] = StemTS (
376- nn .Upsample (scale_factor = 1 / (2 ** res_idx ), mode = mode , align_corners = True ),
376+ nn .Upsample (scale_factor = 1 / (2 ** res_idx ), mode = mode , align_corners = True ),
377377 conv_type (
378378 in_channels = in_channels ,
379379 out_channels = self .filter_nums [res_idx ],
@@ -789,7 +789,7 @@ def get_ram_cost_usage(self, in_size, full: bool = False):
789789 image_size = np .array (in_size [- self ._spatial_dims :])
790790 sizes = []
791791 for res_idx in range (self .num_depths ):
792- sizes .append (batch_size * self .filter_nums [res_idx ] * (image_size // (2 ** res_idx )).prod ())
792+ sizes .append (batch_size * self .filter_nums [res_idx ] * (image_size // (2 ** res_idx )).prod ())
793793 sizes = torch .tensor (sizes ).to (torch .float32 ).to (self .device ) / (2 ** (int (self .use_downsample )))
794794 probs_a , arch_code_prob_a = self .get_prob_a (child = False )
795795 cell_prob = F .softmax (self .log_alpha_c , dim = - 1 )
@@ -807,7 +807,7 @@ def get_ram_cost_usage(self, in_size, full: bool = False):
807807 * (1 + (ram_cost [blk_idx , path_idx ] * cell_prob [blk_idx , path_idx ]).sum ())
808808 * sizes [self .arch_code2out [path_idx ]]
809809 )
810- return usage * 32 / 8 / 1024 ** 2
810+ return usage * 32 / 8 / 1024 ** 2
811811
812812 def get_topology_entropy (self , probs ):
813813 """
0 commit comments