@@ -159,11 +159,16 @@ def validate_args(**kwargs) -> Dict[str, object]:
159159 return kwargs
160160
161161
162+ def to_path (x ):
163+ return Path (x ) if isinstance (x , str ) and x not in ALGOS else x
164+
165+
162166def datasail (
163167 techniques : Union [str , List [str ], Callable [..., List [str ]], Generator [str , None , None ]] = None ,
164168 inter : Optional [
165169 Union [str , Path , List [Tuple [str , str ]], Callable [..., List [str ]], Generator [str , None , None ]]
166170 ] = None ,
171+ output : Optional [Union [str , Path ]] = None ,
167172 max_sec : int = 100 ,
168173 verbose : str = "W" ,
169174 splits : List [float ] = None ,
@@ -200,6 +205,7 @@ def datasail(
200205 Args:
201206 techniques: List of techniques to split based on
202207 inter: Filepath to a TSV file storing interactions of the e-entities and f-entities.
208+ output: Output directory to store the results in.
203209 max_sec: Maximal number of seconds to take for optimizing a found solution.
204210 verbose: Verbosity level for logging.
205211 splits: List of splits, have to add up to one, otherwise scaled accordingly.
@@ -233,11 +239,8 @@ def datasail(
233239 Three dictionaries mapping techniques to another dictionary. The inner dictionary maps input id to their splits.
234240 """
235241
236- def to_path (x ):
237- return Path (x ) if isinstance (x , str ) and x not in ALGOS else x
238-
239242 kwargs = validate_args (
240- output = None , techniques = techniques , inter = to_path (inter ), max_sec = max_sec , verbosity = verbose ,
243+ output = to_path ( output ) , techniques = techniques , inter = to_path (inter ), max_sec = max_sec , verbosity = verbose ,
241244 splits = splits , names = names , delta = delta , epsilon = epsilon , runs = runs , solver = solver , cache = cache ,
242245 cache_dir = to_path (cache_dir ), linkage = linkage , e_type = e_type , e_data = to_path (e_data ),
243246 e_weights = to_path (e_weights ), e_strat = to_path (e_strat ), e_sim = to_path (e_sim ), e_dist = to_path (e_dist ),
@@ -257,5 +260,9 @@ def sail(args=None, **kwargs) -> None:
257260 kwargs = parse_datasail_args (args or sys .argv [1 :])
258261 kwargs = {key : (kwargs [key ] if key in kwargs else val ) for key , val in DEFAULT_KWARGS .items ()}
259262 kwargs [KW_CLI ] = True
263+ for kwarg in [KW_OUTDIR , KW_INTER , KW_CACHE_DIR , KW_E_DATA , KW_E_WEIGHTS , KW_E_STRAT ,
264+ KW_E_SIM , KW_E_DIST , KW_F_DATA , KW_F_WEIGHTS , KW_F_STRAT , KW_F_SIM , KW_F_DIST ]:
265+ if kwarg in kwargs :
266+ kwargs [kwarg ] = to_path (kwargs [kwarg ])
260267 kwargs = validate_args (** kwargs )
261268 datasail_main (** kwargs )
0 commit comments