Skip to content

Commit 4e86665

Browse files
committed
fix rl model path issue
1 parent 0bfebe9 commit 4e86665

File tree

2 files changed

+119
-10
lines changed

2 files changed

+119
-10
lines changed

.github/workflows/agent_battle_pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ jobs:
9191
env:
9292
AGENT1_PATH: agents/${{ github.event.inputs.username1 }}/my_agent.py
9393
AGENT2_PATH: agents/${{ github.event.inputs.username2 }}/my_agent.py
94+
AI2_UNIQUE_MODEL: 1
9495
run: |
9596
echo "Running battle between:"
9697
echo "Agent 1: $AGENT1_PATH"

environment/agent.py

Lines changed: 118 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from PIL import Image, ImageSequence
1414
import matplotlib.pyplot as plt
1515

16-
import gdown, os, math, random, shutil, json
16+
import gdown, os, math, random, shutil, json, uuid
1717

1818
import numpy as np
1919
import torch
@@ -57,6 +57,7 @@ def __init__(
5757
# If no supplied file_path, load from gdown (optional file_path returned)
5858
if file_path is None:
5959
file_path = self._gdown()
60+
file_path = self._maybe_unique_model_path(file_path)
6061

6162
self.file_path: Optional[str] = file_path
6263
self.initialized = False
@@ -71,6 +72,8 @@ def get_env_info(self, env):
7172
self.action_space = self_env.action_space
7273
self.act_helper = self_env.act_helper
7374
self.env = env
75+
# Resolve model path (handles .zip duplication and fallbacks)
76+
self.file_path = self._resolve_file_path(self.file_path)
7477
self._initialize()
7578
self.initialized = True
7679

@@ -109,6 +112,120 @@ def _gdown(self) -> Optional[str]:
109112
"""
110113
return
111114

115+
def _resolve_file_path(self, file_path: Optional[str]) -> Optional[str]:
116+
"""
117+
Best-effort resolution of a provided model path.
118+
119+
- Converts to absolute path
120+
- Tries sibling variants with/without the `.zip` suffix
121+
- Looks in a few common folders (./, ./checkpoints, ./models, ./user/models)
122+
- Falls back to calling _gdown() if nothing is found
123+
124+
Returns a path that exists, or None if unresolved.
125+
"""
126+
try:
127+
if file_path is None:
128+
return None
129+
130+
# If a URL-like path is passed, skip straight to download
131+
is_url = isinstance(file_path, str) and (
132+
file_path.startswith('http://') or file_path.startswith('https://')
133+
or 'drive.google.com' in file_path
134+
)
135+
136+
candidates = []
137+
138+
if not is_url:
139+
base = file_path
140+
if not os.path.isabs(base):
141+
base = os.path.abspath(base)
142+
143+
candidates.append(base)
144+
if base.endswith('.zip'):
145+
candidates.append(base[:-4])
146+
else:
147+
candidates.append(base + '.zip')
148+
149+
# Probe a few common directories with both name variants
150+
cwd = os.getcwd()
151+
name = os.path.basename(base)
152+
probe_dirs = [
153+
'.',
154+
'checkpoints',
155+
'models',
156+
os.path.join('user', 'models'),
157+
]
158+
for d in probe_dirs:
159+
root = os.path.abspath(os.path.join(cwd, d))
160+
p1 = os.path.join(root, name)
161+
candidates.append(p1)
162+
if not name.endswith('.zip'):
163+
candidates.append(p1 + '.zip')
164+
165+
# Return the first existing candidate
166+
for c in candidates:
167+
if isinstance(c, str) and os.path.exists(c):
168+
return c
169+
170+
# Nothing found locally - try to download via _gdown
171+
downloaded = self._gdown()
172+
if isinstance(downloaded, str):
173+
resolved = downloaded if os.path.isabs(downloaded) else os.path.abspath(downloaded)
174+
if os.path.exists(resolved):
175+
# If a specific (non-existing) target was requested, copy to it
176+
if (file_path is not None and not is_url):
177+
target = file_path if os.path.isabs(file_path) else os.path.abspath(file_path)
178+
try:
179+
os.makedirs(os.path.dirname(target), exist_ok=True)
180+
shutil.copy2(resolved, target)
181+
return target
182+
except Exception as copy_e:
183+
warnings.warn(f"Failed to copy downloaded model to requested path: {copy_e}")
184+
# Optionally create a unique-named copy if requested by env
185+
env_toggle = os.environ.get('AI2_UNIQUE_MODEL', '0').lower() in ('1', 'true', 'yes')
186+
if env_toggle:
187+
base_dir = os.environ.get('AI2_MODEL_DIR', os.path.dirname(resolved))
188+
base_name = os.environ.get('AI2_MODEL_BASENAME', 'rl-model')
189+
unique = uuid.uuid4().hex[:8]
190+
target = os.path.abspath(os.path.join(base_dir, f"{base_name}-{unique}.zip"))
191+
try:
192+
os.makedirs(os.path.dirname(target), exist_ok=True)
193+
shutil.copy2(resolved, target)
194+
return target
195+
except Exception as e2:
196+
warnings.warn(f"Failed to create unique model copy: {e2}")
197+
return resolved
198+
except Exception as e:
199+
warnings.warn(f"Model path resolution failed: {e}")
200+
201+
return None
202+
203+
def _maybe_unique_model_path(self, path: Optional[str]) -> Optional[str]:
204+
"""
205+
Optionally duplicate the downloaded model to a uniquely named path when
206+
AI2_UNIQUE_MODEL is enabled. Returns the new path, or the original path
207+
on any failure or when disabled.
208+
"""
209+
try:
210+
if path is None:
211+
return None
212+
env_toggle = os.environ.get('AI2_UNIQUE_MODEL', '0').lower() in ('1', 'true', 'yes')
213+
if not env_toggle:
214+
return path
215+
abs_src = path if os.path.isabs(path) else os.path.abspath(path)
216+
if not os.path.exists(abs_src):
217+
return path
218+
base_dir = os.environ.get('AI2_MODEL_DIR', os.path.dirname(abs_src))
219+
base_name = os.environ.get('AI2_MODEL_BASENAME', 'rl-model')
220+
unique = uuid.uuid4().hex[:8]
221+
target = os.path.abspath(os.path.join(base_dir, f"{base_name}-{unique}.zip"))
222+
os.makedirs(os.path.dirname(target), exist_ok=True)
223+
shutil.copy2(abs_src, target)
224+
return target
225+
except Exception as e:
226+
warnings.warn(f"Unique model naming failed: {e}")
227+
return path
228+
112229

113230
# ### Agent Classes
114231

@@ -638,15 +755,6 @@ def run_match(agent_1: Agent | partial,
638755

639756
# Initialize agents
640757
if not agent_1.initialized: agent_1.get_env_info(env)
641-
642-
# rl_model_path = "rl-model.zip"
643-
# if os.path.exists(rl_model_path):
644-
# try:
645-
# os.remove(rl_model_path)
646-
# print(f"Removed {rl_model_path}")
647-
# except Exception as e:
648-
# print(f"Warning: Could not remove {rl_model_path}: {e}")
649-
650758
if not agent_2.initialized: agent_2.get_env_info(env)
651759
# 596, 336
652760
platform1 = env.objects["platform1"]

0 commit comments

Comments
 (0)