Skip to content

Commit 05d0498

Browse files
committed
fix - change runfiles
1 parent d41a0f6 commit 05d0498

File tree

4 files changed

+76
-275
lines changed

4 files changed

+76
-275
lines changed

rtp_llm/test/utils/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ py_binary(
2727
name = "gpu_lock",
2828
deps = [":torch"],
2929
main = "device_resource.py",
30-
srcs = ["device_resource.py"],
30+
srcs = [
31+
"device_resource.py",
32+
"jit_sys_path_setup.py",
33+
],
3134
)
3235

3336
py_library(

rtp_llm/test/utils/bootstrap_runner.py

Lines changed: 0 additions & 212 deletions
This file was deleted.

rtp_llm/test/utils/device_resource.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -128,23 +128,10 @@ def __exit__(self, *args: Any):
128128

129129
sys.exit(result.returncode)
130130
else:
131-
if "H20" in cuda_info[0]:
132-
# Setup JIT cache and create Bootstrap script
133-
try:
134-
from jit_sys_path_setup import setup_jit_cache_and_create_bootstrap
135-
136-
bootstrap_script = setup_jit_cache_and_create_bootstrap()
137-
except Exception as e:
138-
logging.warning(f"JIT setup failed: {e}, will run without Bootstrap")
139-
bootstrap_script = None
140-
141-
# Prepare test command
142-
if bootstrap_script:
143-
# Use Bootstrap wrapper
144-
test_command = [sys.executable, bootstrap_script] + sys.argv[1:]
145-
else:
146-
# Fallback to direct execution
147-
test_command = sys.argv[1:]
131+
from jit_sys_path_setup import setup_jit_cache
132+
133+
setup_jit_cache()
134+
148135
device_name, _ = cuda_info
149136
require_count = int(
150137
os.environ.get("WORLD_SIZE", os.environ.get("GPU_COUNT", "1"))
@@ -155,11 +142,6 @@ def __exit__(self, *args: Any):
155142
else:
156143
env_name = "CUDA_VISIBLE_DEVICES"
157144
os.environ[env_name] = ",".join(gpu_resource.gpu_ids)
158-
result = subprocess.run(test_command)
145+
result = subprocess.run(sys.argv[1:])
159146
logging.info("exitcode: %d", result.returncode)
160-
161-
# Cleanup Bootstrap script
162-
if bootstrap_script and os.path.exists(bootstrap_script):
163-
os.unlink(bootstrap_script)
164-
165147
sys.exit(result.returncode)

rtp_llm/test/utils/jit_sys_path_setup.py

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -161,23 +161,69 @@ def copy_package_with_lock(package_name, cache_dir):
161161
return None
162162

163163

164-
def setup_jit_cache_and_create_bootstrap(cache_dir=None, packages=None):
164+
def modify_bazel_wrapper_pythonpath(wrapper_path):
165165
"""
166-
Setup JIT package cache and create Bootstrap script.
166+
Modify Bazel-generated wrapper to inject _JIT_CACHE_PATHS at the beginning of PYTHONPATH.
167167
168168
Args:
169-
cache_dir: Cache directory path. Defaults to /home/yangchengjun.ycj/.cache
170-
packages: List of package names to copy. Defaults to ["flashinfer", "torch", "deep_gemm"]
171-
logger: Logger instance for output. If None, uses logging.info()
172-
173-
Returns:
174-
Bootstrap script path (str) or None if setup fails
169+
wrapper_path: Path to the Bazel-generated wrapper file
175170
"""
176-
import tempfile
171+
try:
172+
with open(wrapper_path, "r") as f:
173+
lines = f.readlines()
174+
# Find the line index where new_env['PYTHONPATH'] = python_path (line 479)
175+
target_line_idx = None
176+
for i, line in enumerate(lines):
177+
if "new_env['PYTHONPATH'] = python_path" in line:
178+
target_line_idx = i
179+
break
180+
if target_line_idx is None:
181+
logging.warning(
182+
f"[Package Setup] Could not find target line in wrapper: {wrapper_path}"
183+
)
184+
return False
185+
186+
# Create injection code to insert before line 479
187+
injection_lines = [
188+
" # Inject _JIT_CACHE_PATHS at the beginning of PYTHONPATH\n",
189+
" jit_cache_paths = os.environ.get('_JIT_CACHE_PATHS', '')\n",
190+
" if jit_cache_paths:\n",
191+
" jit_cache_entries = jit_cache_paths.split(os.pathsep)\n",
192+
" # Prepend cache paths to the beginning of python_path\n",
193+
" python_path = os.pathsep.join(jit_cache_entries) + os.pathsep + python_path\n",
194+
]
195+
196+
# Insert the code before the target line
197+
lines[target_line_idx:target_line_idx] = injection_lines
198+
199+
import stat
200+
201+
if os.path.exists(wrapper_path):
202+
# Make file writable
203+
current_permissions = os.stat(wrapper_path).st_mode
204+
os.chmod(wrapper_path, current_permissions | stat.S_IWRITE)
205+
206+
# Write back to file
207+
with open(wrapper_path, "w") as f:
208+
f.writelines(lines)
209+
210+
logging.info(f"[Package Setup] Modified Bazel wrapper: {wrapper_path}")
211+
logging.info(
212+
f"[Package Setup] Injected _JIT_CACHE_PATHS at the beginning of PYTHONPATH"
213+
)
214+
return True
215+
216+
except Exception as e:
217+
logging.warning(
218+
f"[Package Setup] Failed to modify Bazel wrapper {wrapper_path}: {e}"
219+
)
220+
return False
177221

222+
223+
def setup_jit_cache(cache_dir=None, packages=None):
178224
# Use defaults if not provided
179225
if cache_dir is None:
180-
cache_dir = "/home/yangchengjun.ycj/.cache"
226+
cache_dir = Path.home().as_posix() + "/.cache"
181227
if packages is None:
182228
packages = ["flashinfer", "torch", "deep_gemm"]
183229

@@ -201,32 +247,14 @@ def setup_jit_cache_and_create_bootstrap(cache_dir=None, packages=None):
201247
logging.info(
202248
f"[Package Setup] Set _JIT_CACHE_PATHS: {os.environ['_JIT_CACHE_PATHS']}"
203249
)
204-
205-
# Store current working directory for Bootstrap script to use
206-
os.environ["_JIT_ORIGINAL_CWD"] = os.getcwd()
207-
logging.info(
208-
f"[Package Setup] Set _JIT_ORIGINAL_CWD: {os.environ['_JIT_ORIGINAL_CWD']}"
209-
)
210-
211-
# Read Bootstrap runner template
212-
bootstrap_template_path = Path(__file__).parent / "bootstrap_runner.py"
213-
try:
214-
with open(bootstrap_template_path, "r") as f:
215-
bootstrap_code = f.read()
216-
except FileNotFoundError:
217-
logging.error(
218-
f"[Package Setup] ERROR: Bootstrap template not found: {bootstrap_template_path}"
219-
)
220-
return None
221-
222-
# Write bootstrap script to temporary file
223-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
224-
bootstrap_script = f.name
225-
f.write(bootstrap_code)
226-
227-
logging.info("=" * 80)
228-
logging.info(f"[Package Setup] Created bootstrap script: {bootstrap_script}")
229-
logging.info(f"[Package Setup] Bootstrap script ready for execution")
230-
logging.info("=" * 80)
231-
232-
return bootstrap_script
250+
runfiles_dir = os.environ.get("RUNFILES_DIR", None)
251+
test_binary = sys.argv[1]
252+
bazel_wrapper_path = os.path.join(runfiles_dir, "rtp_llm/" + test_binary)
253+
bazel_wrapper_path_new = bazel_wrapper_path + "_new"
254+
if os.path.exists(bazel_wrapper_path_new):
255+
os.remove(bazel_wrapper_path_new)
256+
logging.info(f"[Package Setup] Removed existing file: {bazel_wrapper_path_new}")
257+
shutil.copy2(bazel_wrapper_path, bazel_wrapper_path_new)
258+
logging.info(f"[Package Setup] Copied Bazel wrapper to: {bazel_wrapper_path_new}")
259+
modify_bazel_wrapper_pythonpath(bazel_wrapper_path_new)
260+
sys.argv[1] = test_binary + "_new"

0 commit comments

Comments
 (0)