1616import tempfile
1717import time
1818
19+ import numpy as np
1920import requests
2021from PIL import Image
2122from requests .auth import HTTPBasicAuth
@@ -39,6 +40,7 @@ def __init__(
3940 image_quality = 70 ,
4041 labels = None ,
4142 normalize_label = True ,
43+ segment_size = 1 ,
4244 ** kwargs ,
4345 ):
4446 default_labels = [
@@ -47,15 +49,17 @@ def __init__(
4749 {"name" : "OutBody" , "attributes" : [], "color" : "#0000ff" },
4850 ]
4951 labels = labels if labels else default_labels
50- labels = json .loads (labels ) if isinstance (labels , str ) else self . json ()
52+ labels = json .loads (labels ) if isinstance (labels , str ) else labels
5153
5254 self .api_url = api_url .rstrip ("/" ).strip ()
5355 self .auth = HTTPBasicAuth (username , password ) if username else None
5456 self .project = project
5557 self .task_prefix = task_prefix
5658 self .image_quality = image_quality
5759 self .labels = labels
60+ self .label_map = {l ["name" ]: idx for idx , l in enumerate (labels , start = 1 )}
5861 self .normalize_label = normalize_label
62+ self .segment_size = segment_size
5963
6064 logger .info (f"CVAT:: API URL: { api_url } " )
6165 logger .info (f"CVAT:: UserName: { username } " )
@@ -65,6 +69,7 @@ def __init__(
6569 logger .info (f"CVAT:: Image Quality: { image_quality } " )
6670 logger .info (f"CVAT:: Labels: { labels } " )
6771 logger .info (f"CVAT:: Normalize Label: { normalize_label } " )
72+ logger .info (f"CVAT:: Segment Size: { normalize_label } " )
6873
6974 super ().__init__ (datastore_path = datastore_path , ** kwargs )
7075
@@ -117,7 +122,10 @@ def get_cvat_task_id(self, project_id, create):
117122 task_name = f"{ self .task_prefix } _{ version } "
118123 logger .info (f"Creating new CVAT Task: { task_name } ; project: { self .project } " )
119124
120- body = {"name" : task_name , "labels" : [], "project_id" : project_id , "subset" : "Train" , "segment_size" : 1 }
125+ body = {"name" : task_name , "labels" : [], "project_id" : project_id , "subset" : "Train" }
126+ if self .segment_size :
127+ body ["segment_size" ] = self .segment_size
128+
121129 task = requests .post (f"{ self .api_url } /api/tasks" , auth = self .auth , json = body ).json ()
122130 logger .debug (task )
123131 task_id = task ["id" ]
@@ -157,6 +165,18 @@ def trigger_automation(self, function):
157165 r = requests .post (f"{ self .api_url } /api/lambda/requests?org=" , json = body , auth = self .auth ).json ()
158166 logger .info (r )
159167
168+ def _load_labelmap_txt (self , file ):
169+ labelmap = {}
170+ if os .path .exists (file ):
171+ with open (file ) as f :
172+ for line in f .readlines ():
173+ if line and not line .startswith ("#" ):
174+ fields = line .split (":" )
175+ name = fields [0 ]
176+ rgb = tuple (int (c ) for c in fields [1 ].split ("," ))
177+ labelmap [name ] = rgb
178+ return labelmap
179+
160180 def download_from_cvat (self , max_retry_count = 5 , retry_wait_time = 10 ):
161181 if self .task_status () != "completed" :
162182 logger .info ("No Tasks exists with completed status to refresh/download the final labels" )
@@ -190,10 +210,19 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10):
190210
191211 dest = os .path .join (final_labels , f )
192212 if self .normalize_label :
193- Image .open (label ).convert ("L" ).point (lambda x : 0 if x < 128 else 255 , "1" ).save (dest )
213+ img = np .array (Image .open (label ))
214+ mask = np .zeros_like (img )
215+
216+ labelmap = self ._load_labelmap_txt (os .path .join (tmp_folder , "labelmap.txt" ))
217+ for name , color in labelmap .items ():
218+ if name in self .label_map :
219+ idx = self .label_map .get (name )
220+ mask [np .all (img == color , axis = - 1 )] = idx
221+ Image .fromarray (mask [:, :, 0 ]).save (dest ) # single channel
222+ logger .info (f"Copy Final Label: { label } to { dest } ; unique: { np .unique (mask )} " )
194223 else :
195224 Image .open (label ).save (dest )
196- logger .info (f"Copy Final Label: { label } to { dest } " )
225+ logger .info (f"Copy Final Label: { label } to { dest } " )
197226
198227 # Rename after consuming/downloading the labels
199228 patch_url = f"{ self .api_url } /api/tasks/{ task_id } "
@@ -206,3 +235,51 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10):
206235 logger .error (f"{ retry } => Failed to download..." )
207236 retry_count = retry_count + 1
208237 return None
238+
239+
240+ def main ():
241+ from pathlib import Path
242+
243+ from monailabel .config import settings
244+
245+ settings .MONAI_LABEL_DATASTORE_AUTO_RELOAD = False
246+ settings .MONAI_LABEL_DATASTORE_FILE_EXT = ["*.png" , "*.jpg" , "*.jpeg" , ".xml" ]
247+ settings .MONAI_LABEL_DATASTORE = "cvat"
248+ settings .MONAI_LABEL_DATASTORE_URL = "http://10.117.19.88:8080"
249+ settings .MONAI_LABEL_DATASTORE_USERNAME = "sachi"
250+ settings .MONAI_LABEL_DATASTORE_PASSWORD = "sachi"
251+
252+ os .putenv ("MASTER_ADDR" , "127.0.0.1" )
253+ os .putenv ("MASTER_PORT" , "1234" )
254+
255+ logging .basicConfig (
256+ level = logging .INFO ,
257+ format = "[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s" ,
258+ datefmt = "%Y-%m-%d %H:%M:%S" ,
259+ force = True ,
260+ )
261+
262+ home = str (Path .home ())
263+ studies = f"{ home } /Dataset/picked/all"
264+
265+ ds = CVATDatastore (
266+ datastore_path = studies ,
267+ api_url = settings .MONAI_LABEL_DATASTORE_URL ,
268+ username = settings .MONAI_LABEL_DATASTORE_USERNAME ,
269+ password = settings .MONAI_LABEL_DATASTORE_PASSWORD ,
270+ project = "MONAILabel" ,
271+ task_prefix = "ActiveLearning_Iteration" ,
272+ image_quality = 70 ,
273+ labels = None ,
274+ normalize_label = True ,
275+ segment_size = 0 ,
276+ extensions = settings .MONAI_LABEL_DATASTORE_FILE_EXT ,
277+ auto_reload = settings .MONAI_LABEL_DATASTORE_AUTO_RELOAD ,
278+ )
279+ ds .download_from_cvat ()
280+
281+ # studies = f"{home}/Dataset/Holoscan/flattened/images"
282+
283+
284+ if __name__ == "__main__" :
285+ main ()
0 commit comments