File tree Expand file tree Collapse file tree 2 files changed +13
-1
lines changed Expand file tree Collapse file tree 2 files changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -236,6 +236,7 @@ def __init__(
236236 self ,
237237 dim ,
238238 * ,
239+ image_size ,
239240 patch_size ,
240241 channels = 3 ,
241242 channels_out = None ,
@@ -244,6 +245,13 @@ def __init__(
244245 depth = 4 ,
245246 ):
246247 super ().__init__ ()
248+ assert exists (image_size )
249+ assert (image_size % patch_size ) == 0
250+
251+ num_patches_height_width = image_size // patch_size
252+
253+ self .pos_emb = nn .Parameter (torch .zeros (dim , num_patches_height_width , num_patches_height_width ))
254+
247255 channels_out = default (channels_out , channels )
248256
249257 patch_dim = channels * (patch_size ** 2 )
@@ -272,6 +280,8 @@ def __init__(
272280
273281 def forward (self , x ):
274282 x = self .to_tokens (x )
283+ x = x + self .pos_emb
284+
275285 x = self .transformer (x )
276286 return self .to_patches (x )
277287
@@ -283,6 +293,7 @@ def __init__(
283293 fmap_size ,
284294 dim ,
285295 dynamic = True ,
296+ image_size = None ,
286297 dim_head = 32 ,
287298 heads = 4 ,
288299 depth = 4 ,
@@ -412,6 +423,7 @@ def __init__(
412423 if conditioning_klass == Conditioning :
413424 conditioning_klass = partial (
414425 Conditioning ,
426+ image_size = image_size ,
415427 dynamic = dynamic_ff_parser_attn_map ,
416428 ** conditioning_kwargs
417429 )
Original file line number Diff line number Diff line change 33setup (
44 name = 'med-seg-diff-pytorch' ,
55 packages = find_packages (exclude = []),
6- version = '0.3.0 ' ,
6+ version = '0.3.1 ' ,
77 license = 'MIT' ,
88 description = 'MedSegDiff - SOTA medical image segmentation - Pytorch' ,
99 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments