@@ -133,12 +133,11 @@ def init_scoring_methods(self) -> Dict[str, ScoringMethod]:
133133
134134def main ():
135135 import argparse
136+ import shutil
136137 from pathlib import Path
137138
138- from monailabel .config import settings
139+ from monailabel .utils . others . generic import device_list , file_ext
139140
140- settings .MONAI_LABEL_DATASTORE_AUTO_RELOAD = False
141- settings .MONAI_LABEL_DATASTORE_FILE_EXT = ["*.png" , "*.jpg" , "*.jpeg" , ".nii" , ".nii.gz" ]
142141 os .putenv ("MASTER_ADDR" , "127.0.0.1" )
143142 os .putenv ("MASTER_PORT" , "1234" )
144143
@@ -154,43 +153,71 @@ def main():
154153
155154 parser = argparse .ArgumentParser ()
156155 parser .add_argument ("-s" , "--studies" , default = studies )
156+ parser .add_argument ("-m" , "--model" , default = "wholeBody_ct_segmentation" )
157+ parser .add_argument ("-t" , "--test" , default = "infer" , choices = ("train" , "infer" , "batch_infer" ))
157158 args = parser .parse_args ()
158159
159160 app_dir = os .path .dirname (__file__ )
160161 studies = args .studies
162+ conf = {
163+ "models" : args .model ,
164+ "preload" : "false" ,
165+ }
166+
167+ app = MyApp (app_dir , studies , conf )
168+
169+ # Infer
170+ if args .test == "infer" :
171+ sample = app .next_sample (request = {"strategy" : "first" })
172+ image_id = sample ["id" ]
173+ image_path = sample ["path" ]
174+
175+ # Run on all devices
176+ for device in device_list ():
177+ res = app .infer (request = {"model" : args .model , "image" : image_id , "device" : device })
178+ label = res ["file" ]
179+ label_json = res ["params" ]
180+ test_dir = os .path .join (args .studies , "test_labels" )
181+ os .makedirs (test_dir , exist_ok = True )
182+
183+ label_file = os .path .join (test_dir , image_id + file_ext (image_path ))
184+ shutil .move (label , label_file )
185+
186+ print (label_json )
187+ print (f"++++ Image File: { image_path } " )
188+ print (f"++++ Label File: { label_file } " )
189+ break
190+ return
191+
192+ # Batch Infer
193+ if args .test == "batch_infer" :
194+ app .batch_infer (
195+ request = {
196+ "model" : args .model ,
197+ "multi_gpu" : False ,
198+ "save_label" : True ,
199+ "label_tag" : "original" ,
200+ "max_workers" : 1 ,
201+ "max_batch_size" : 0 ,
202+ }
203+ )
204+ return
161205
162- app = MyApp (app_dir , studies , {"preload" : "false" , "models" : "spleen_deepedit_annotation" })
163- # train(app)
164- infer (app )
165-
166-
167- def infer (app ):
168- import json
169- import shutil
170-
171- res = app .infer (
172- request = {
173- "model" : "spleen_deepedit_annotation" ,
174- "image" : "image" ,
175- }
176- )
177-
178- print (json .dumps (res , indent = 2 ))
179- shutil .move (res ["label" ], os .path .join (app .studies , "test" ))
180- logger .info ("All Done!" )
181-
182-
183- def train (app ):
206+ # Train
184207 app .train (
185208 request = {
186- "model" : "spleen_deepedit_annotation" ,
187- "max_epochs" : 2 ,
209+ "model" : args .model ,
210+ "max_epochs" : 10 ,
211+ "dataset" : "Dataset" , # PersistentDataset, CacheDataset
212+ "train_batch_size" : 1 ,
213+ "val_batch_size" : 1 ,
188214 "multi_gpu" : False ,
189215 "val_split" : 0.1 ,
190- "val_interval" : 1 ,
191216 },
192217 )
193218
194219
195220if __name__ == "__main__" :
221+ # export PYTHONPATH=~/Projects/MONAILabel:`pwd`
222+ # python main.py
196223 main ()
0 commit comments