Skip to content

Commit 44265e4

Browse files
feat: pass env_vars in engine
1 parent a6d32d7 commit 44265e4

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

graphgen/engine.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import os
12
import inspect
23
import logging
34
from collections import defaultdict, deque
45
from functools import wraps
56
from typing import Any, Callable, Dict, List, Set
7+
from dotenv import load_dotenv
68

79
import ray
810
import ray.data
@@ -12,6 +14,7 @@
1214
from graphgen.utils import logger
1315
from graphgen.common import init_llm
1416

17+
load_dotenv()
1518

1619
class Engine:
1720
def __init__(
@@ -31,6 +34,16 @@ def __init__(
3134
ctx.enable_tensor_extension_casting = False
3235
ctx._metrics_export_port = 0 # Disable metrics exporter to avoid RpcError
3336

37+
all_env_vars = os.environ.copy()
38+
if "runtime_env" not in ray_init_kwargs:
39+
ray_init_kwargs["runtime_env"] = {}
40+
41+
existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
42+
ray_init_kwargs["runtime_env"]["env_vars"] = {
43+
**all_env_vars,
44+
**existing_env_vars
45+
}
46+
3447
if not ray.is_initialized():
3548
context = ray.init(
3649
ignore_reinit_error=True,

graphgen/run.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import ray
88
import yaml
9-
from dotenv import load_dotenv
109
from ray.data.block import Block
1110
from ray.data.datasource.filename_provider import FilenameProvider
1211

@@ -16,8 +15,6 @@
1615

1716
sys_path = os.path.abspath(os.path.dirname(__file__))
1817

19-
load_dotenv()
20-
2118

2219
def set_working_dir(folder):
2320
os.makedirs(folder, exist_ok=True)

0 commit comments

Comments
 (0)