1111from modelscope .hub .api import ModelScopeConfig
1212from tqdm import tqdm
1313
14- from .env import is_master
14+ from .env import is_last_rank , is_master
1515from .logger import get_logger
1616from .utils import check_json_format
1717
@@ -46,8 +46,20 @@ def write_to_jsonl(fpath: str, obj_list: List[Any], encoding: str = 'utf-8') ->
4646
4747class JsonlWriter :
4848
49- def __init__ (self , fpath : str , * , encoding : str = 'utf-8' , strict : bool = True , enable_async : bool = False ):
50- self .fpath = os .path .abspath (os .path .expanduser (fpath )) if is_master () else None
49+ def __init__ (self ,
50+ fpath : str ,
51+ * ,
52+ encoding : str = 'utf-8' ,
53+ strict : bool = True ,
54+ enable_async : bool = False ,
55+ write_on_rank : Literal ['master' , 'last' ] = 'master' ):
56+ if write_on_rank == 'master' :
57+ self .is_write_rank = is_master ()
58+ elif write_on_rank == 'last' :
59+ self .is_write_rank = is_last_rank ()
60+ else :
61+ raise ValueError (f"Invalid `write_on_rank`: { write_on_rank } , should be 'master' or 'last'" )
62+ self .fpath = os .path .abspath (os .path .expanduser (fpath )) if self .is_write_rank else None
5163 self .encoding = encoding
5264 self .strict = strict
5365 self .enable_async = enable_async
@@ -66,7 +78,7 @@ def _append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False):
6678 obj_list = [obj ]
6779 if gather_obj and dist .is_initialized ():
6880 obj_list = gather_object (obj_list )
69- if not is_master () :
81+ if not self . is_write_rank :
7082 return
7183 obj_list = check_json_format (obj_list )
7284 for i , _obj in enumerate (obj_list ):
@@ -85,7 +97,7 @@ def append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False):
8597 def _write_buffer (self , text : str ):
8698 if not text :
8799 return
88- assert is_master () , f'is_master() : { is_master () } '
100+ assert self . is_write_rank , f'self.is_write_rank : { self . is_write_rank } '
89101 try :
90102 os .makedirs (os .path .dirname (self .fpath ), exist_ok = True )
91103 with open (self .fpath , 'a' , encoding = self .encoding ) as f :
0 commit comments