1111from cinema import CineMA , patchify , unpatchify
1212
1313
14- def run () -> None :
14+ def run (device : torch . device , dtype : torch . dtype ) -> None :
1515 """Run MAE reconstruction."""
1616 # load model
1717 model = CineMA .from_pretrained ()
18+ model .to (device )
1819 model .eval ()
1920
2021 # load sample data and form a batch of size 1
2122 transform = Compose (
2223 [
2324 ScaleIntensityd (keys = ("sax" , "lax_2c" , "lax_3c" , "lax_4c" ), allow_missing_keys = True ),
24- SpatialPadd (keys = "sax" , spatial_size = (192 , 192 , 16 ), method = "end" , lazy = True , allow_missing_keys = True ),
25+ SpatialPadd (keys = "sax" , spatial_size = (192 , 192 , 16 ), method = "end" ),
2526 SpatialPadd (
2627 keys = ("lax_2c" , "lax_3c" , "lax_4c" ),
2728 spatial_size = (256 , 256 ),
@@ -47,17 +48,17 @@ def run() -> None:
4748 )
4849 t = 25 # which time frame to use
4950 batch = {
50- "sax" : sax_image [None , ..., t ]. to ( dtype = torch . float32 ) ,
51- "lax_2c" : lax_2c_image [None , ..., 0 , t ]. to ( dtype = torch . float32 ) ,
52- "lax_3c" : lax_3c_image [None , ..., 0 , t ]. to ( dtype = torch . float32 ) ,
53- "lax_4c" : lax_4c_image [None , ..., 0 , t ]. to ( dtype = torch . float32 ) ,
51+ "sax" : sax_image [None , ..., t ],
52+ "lax_2c" : lax_2c_image [None , ..., 0 , t ],
53+ "lax_3c" : lax_3c_image [None , ..., 0 , t ],
54+ "lax_4c" : lax_4c_image [None , ..., 0 , t ],
5455 }
5556 batch = transform (batch )
5657 print (f"SAX view had originally { sax_image .shape [- 2 ]} slices, now zero-padded to { batch ['sax' ].shape [- 1 ]} slices." ) # noqa: T201
57- batch = {k : v [None , ...] for k , v in batch .items ()} # batch size 1
58+ batch = {k : v [None , ...]. to ( device = device , dtype = dtype ) for k , v in batch .items ()}
5859
5960 # forward
60- with torch .no_grad (), torch .autocast ("cuda" , enabled = torch .cuda .is_available ()):
61+ with torch .no_grad (), torch .autocast ("cuda" , dtype = dtype , enabled = torch .cuda .is_available ()):
6162 _ , pred_dict , enc_mask_dict , _ = model (batch , enc_mask_ratio = 0.75 )
6263
6364 # visualize
@@ -76,8 +77,8 @@ def run() -> None:
7677 patch_size = model .dec_patch_size_dict [view ],
7778 grid_size = model .enc_down_dict [view ].patch_embed .grid_size ,
7879 )
79- reconstructed = reconstructed [0 , 0 ].detach ().numpy ()
80- image = batch [view ][0 , 0 ].detach ().numpy ()
80+ reconstructed = reconstructed [0 , 0 ].detach ().cpu (). numpy ()
81+ image = batch [view ][0 , 0 ].detach ().cpu (). numpy ()
8182 error = np .abs (reconstructed - image )
8283
8384 if view == "sax" :
@@ -104,4 +105,10 @@ def run() -> None:
104105
105106
106107if __name__ == "__main__" :
107- run ()
108+ dtype , device = torch .float32 , torch .device ("cpu" )
109+ if torch .cuda .is_available ():
110+ device = torch .device ("cuda" )
111+ if torch .cuda .is_bf16_supported ():
112+ dtype = torch .bfloat16
113+
114+ run (device , dtype )