1313from PIL import Image , ImageSequence
1414import matplotlib .pyplot as plt
1515
16- import gdown , os , math , random , shutil , json
16+ import gdown , os , math , random , shutil , json , uuid
1717
1818import numpy as np
1919import torch
@@ -57,6 +57,7 @@ def __init__(
5757 # If no supplied file_path, load from gdown (optional file_path returned)
5858 if file_path is None :
5959 file_path = self ._gdown ()
60+ file_path = self ._maybe_unique_model_path (file_path )
6061
6162 self .file_path : Optional [str ] = file_path
6263 self .initialized = False
@@ -71,6 +72,8 @@ def get_env_info(self, env):
7172 self .action_space = self_env .action_space
7273 self .act_helper = self_env .act_helper
7374 self .env = env
75+ # Resolve model path (handles .zip duplication and fallbacks)
76+ self .file_path = self ._resolve_file_path (self .file_path )
7477 self ._initialize ()
7578 self .initialized = True
7679
@@ -109,6 +112,120 @@ def _gdown(self) -> Optional[str]:
109112 """
110113 return
111114
115+ def _resolve_file_path (self , file_path : Optional [str ]) -> Optional [str ]:
116+ """
117+ Best-effort resolution of a provided model path.
118+
119+ - Converts to absolute path
120+ - Tries sibling variants with/without the `.zip` suffix
121+ - Looks in a few common folders (./, ./checkpoints, ./models, ./user/models)
122+ - Falls back to calling _gdown() if nothing is found
123+
124+ Returns a path that exists, or None if unresolved.
125+ """
126+ try :
127+ if file_path is None :
128+ return None
129+
130+ # If a URL-like path is passed, skip straight to download
131+ is_url = isinstance (file_path , str ) and (
132+ file_path .startswith ('http://' ) or file_path .startswith ('https://' )
133+ or 'drive.google.com' in file_path
134+ )
135+
136+ candidates = []
137+
138+ if not is_url :
139+ base = file_path
140+ if not os .path .isabs (base ):
141+ base = os .path .abspath (base )
142+
143+ candidates .append (base )
144+ if base .endswith ('.zip' ):
145+ candidates .append (base [:- 4 ])
146+ else :
147+ candidates .append (base + '.zip' )
148+
149+ # Probe a few common directories with both name variants
150+ cwd = os .getcwd ()
151+ name = os .path .basename (base )
152+ probe_dirs = [
153+ '.' ,
154+ 'checkpoints' ,
155+ 'models' ,
156+ os .path .join ('user' , 'models' ),
157+ ]
158+ for d in probe_dirs :
159+ root = os .path .abspath (os .path .join (cwd , d ))
160+ p1 = os .path .join (root , name )
161+ candidates .append (p1 )
162+ if not name .endswith ('.zip' ):
163+ candidates .append (p1 + '.zip' )
164+
165+ # Return the first existing candidate
166+ for c in candidates :
167+ if isinstance (c , str ) and os .path .exists (c ):
168+ return c
169+
170+ # Nothing found locally - try to download via _gdown
171+ downloaded = self ._gdown ()
172+ if isinstance (downloaded , str ):
173+ resolved = downloaded if os .path .isabs (downloaded ) else os .path .abspath (downloaded )
174+ if os .path .exists (resolved ):
175+ # If a specific (non-existing) target was requested, copy to it
176+ if (file_path is not None and not is_url ):
177+ target = file_path if os .path .isabs (file_path ) else os .path .abspath (file_path )
178+ try :
179+ os .makedirs (os .path .dirname (target ), exist_ok = True )
180+ shutil .copy2 (resolved , target )
181+ return target
182+ except Exception as copy_e :
183+ warnings .warn (f"Failed to copy downloaded model to requested path: { copy_e } " )
184+ # Optionally create a unique-named copy if requested by env
185+ env_toggle = os .environ .get ('AI2_UNIQUE_MODEL' , '0' ).lower () in ('1' , 'true' , 'yes' )
186+ if env_toggle :
187+ base_dir = os .environ .get ('AI2_MODEL_DIR' , os .path .dirname (resolved ))
188+ base_name = os .environ .get ('AI2_MODEL_BASENAME' , 'rl-model' )
189+ unique = uuid .uuid4 ().hex [:8 ]
190+ target = os .path .abspath (os .path .join (base_dir , f"{ base_name } -{ unique } .zip" ))
191+ try :
192+ os .makedirs (os .path .dirname (target ), exist_ok = True )
193+ shutil .copy2 (resolved , target )
194+ return target
195+ except Exception as e2 :
196+ warnings .warn (f"Failed to create unique model copy: { e2 } " )
197+ return resolved
198+ except Exception as e :
199+ warnings .warn (f"Model path resolution failed: { e } " )
200+
201+ return None
202+
203+ def _maybe_unique_model_path (self , path : Optional [str ]) -> Optional [str ]:
204+ """
205+ Optionally duplicate the downloaded model to a uniquely named path when
206+ AI2_UNIQUE_MODEL is enabled. Returns the new path, or the original path
207+ on any failure or when disabled.
208+ """
209+ try :
210+ if path is None :
211+ return None
212+ env_toggle = os .environ .get ('AI2_UNIQUE_MODEL' , '0' ).lower () in ('1' , 'true' , 'yes' )
213+ if not env_toggle :
214+ return path
215+ abs_src = path if os .path .isabs (path ) else os .path .abspath (path )
216+ if not os .path .exists (abs_src ):
217+ return path
218+ base_dir = os .environ .get ('AI2_MODEL_DIR' , os .path .dirname (abs_src ))
219+ base_name = os .environ .get ('AI2_MODEL_BASENAME' , 'rl-model' )
220+ unique = uuid .uuid4 ().hex [:8 ]
221+ target = os .path .abspath (os .path .join (base_dir , f"{ base_name } -{ unique } .zip" ))
222+ os .makedirs (os .path .dirname (target ), exist_ok = True )
223+ shutil .copy2 (abs_src , target )
224+ return target
225+ except Exception as e :
226+ warnings .warn (f"Unique model naming failed: { e } " )
227+ return path
228+
112229
113230# ### Agent Classes
114231
@@ -638,15 +755,6 @@ def run_match(agent_1: Agent | partial,
638755
639756 # Initialize agents
640757 if not agent_1 .initialized : agent_1 .get_env_info (env )
641-
642- # rl_model_path = "rl-model.zip"
643- # if os.path.exists(rl_model_path):
644- # try:
645- # os.remove(rl_model_path)
646- # print(f"Removed {rl_model_path}")
647- # except Exception as e:
648- # print(f"Warning: Could not remove {rl_model_path}: {e}")
649-
650758 if not agent_2 .initialized : agent_2 .get_env_info (env )
651759 # 596, 336
652760 platform1 = env .objects ["platform1" ]
0 commit comments