Skip to content

Commit 577e7e1

Browse files
authored
feature(chat ui): inference provider tab first pass (#14)
1 parent dff8210 commit 577e7e1

File tree

6 files changed

+3113
-110
lines changed

6 files changed

+3113
-110
lines changed

brev/welcome-ui/server.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,45 @@
3131
POLICY_FILE = os.path.join(SANDBOX_DIR, "policy.yaml")
3232

3333
LOG_FILE = "/tmp/nemoclaw-sandbox-create.log"
34+
PROVIDER_CONFIG_CACHE = "/tmp/nemoclaw-provider-config-cache.json"
3435
BREV_ENV_ID = os.environ.get("BREV_ENV_ID", "")
3536
_detected_brev_id = ""
3637

3738
SANDBOX_PORT = 18789
3839

40+
_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]")
41+
42+
def _strip_ansi(text: str) -> str:
43+
return _ANSI_RE.sub("", text)
44+
45+
def _read_config_cache() -> dict:
46+
try:
47+
with open(PROVIDER_CONFIG_CACHE) as f:
48+
return json.load(f)
49+
except (FileNotFoundError, json.JSONDecodeError):
50+
return {}
51+
52+
53+
def _write_config_cache(cache: dict) -> None:
54+
try:
55+
with open(PROVIDER_CONFIG_CACHE, "w") as f:
56+
json.dump(cache, f)
57+
except OSError:
58+
pass
59+
60+
61+
def _cache_provider_config(name: str, config: dict) -> None:
62+
cache = _read_config_cache()
63+
cache[name] = config
64+
_write_config_cache(cache)
65+
66+
67+
def _remove_cached_provider(name: str) -> None:
68+
cache = _read_config_cache()
69+
cache.pop(name, None)
70+
_write_config_cache(cache)
71+
72+
3973
_sandbox_lock = threading.Lock()
4074
_sandbox_state = {
4175
"status": "idle", # idle | creating | running | error
@@ -93,6 +127,9 @@ def _run_inject_key(key: str, key_hash: str) -> None:
93127
return
94128

95129
_inject_log(f"step 3/3: SUCCESS — provider nvidia-inference updated")
130+
_cache_provider_config("nvidia-inference", {
131+
"OPENAI_BASE_URL": "https://inference-api.nvidia.com/v1",
132+
})
96133
with _inject_key_lock:
97134
_inject_key_state["status"] = "done"
98135
_inject_key_state["error"] = None
@@ -455,6 +492,20 @@ def _route(self):
455492
if path == "/api/inject-key" and self.command == "POST":
456493
return self._handle_inject_key()
457494

495+
if path == "/api/providers" and self.command == "GET":
496+
return self._handle_providers_list()
497+
if path == "/api/providers" and self.command == "POST":
498+
return self._handle_provider_create()
499+
if re.match(r"^/api/providers/[\w-]+$", path) and self.command == "PUT":
500+
return self._handle_provider_update(path.split("/")[-1])
501+
if re.match(r"^/api/providers/[\w-]+$", path) and self.command == "DELETE":
502+
return self._handle_provider_delete(path.split("/")[-1])
503+
504+
if path == "/api/cluster-inference" and self.command == "GET":
505+
return self._handle_cluster_inference_get()
506+
if path == "/api/cluster-inference" and self.command == "POST":
507+
return self._handle_cluster_inference_set()
508+
458509
if _sandbox_ready():
459510
return self._proxy_to_sandbox()
460511

@@ -658,6 +709,234 @@ def _handle_inject_key(self):
658709

659710
return self._json_response(202, {"ok": True, "started": True})
660711

712+
# -- Provider CRUD --------------------------------------------------
713+
714+
@staticmethod
715+
def _parse_provider_detail(stdout: str) -> dict | None:
716+
"""Parse the text output of ``nemoclaw provider get <name>``."""
717+
info: dict = {}
718+
for line in stdout.splitlines():
719+
line = _strip_ansi(line).strip()
720+
if line.startswith("Id:"):
721+
info["id"] = line.split(":", 1)[1].strip()
722+
elif line.startswith("Name:"):
723+
info["name"] = line.split(":", 1)[1].strip()
724+
elif line.startswith("Type:"):
725+
info["type"] = line.split(":", 1)[1].strip()
726+
elif line.startswith("Credential keys:"):
727+
raw = line.split(":", 1)[1].strip()
728+
info["credentialKeys"] = (
729+
[k.strip() for k in raw.split(",") if k.strip()]
730+
if raw and raw != "<none>" else []
731+
)
732+
elif line.startswith("Config keys:"):
733+
raw = line.split(":", 1)[1].strip()
734+
info["configKeys"] = (
735+
[k.strip() for k in raw.split(",") if k.strip()]
736+
if raw and raw != "<none>" else []
737+
)
738+
return info if "name" in info else None
739+
740+
def _handle_providers_list(self):
741+
try:
742+
result = subprocess.run(
743+
["nemoclaw", "provider", "list", "--names"],
744+
capture_output=True, text=True, timeout=30,
745+
)
746+
if result.returncode != 0:
747+
return self._json_response(502, {
748+
"ok": False,
749+
"error": (result.stderr or result.stdout or "provider list failed").strip(),
750+
})
751+
names = [n.strip() for n in result.stdout.strip().splitlines() if n.strip()]
752+
except Exception as exc:
753+
return self._json_response(502, {"ok": False, "error": str(exc)})
754+
755+
providers = []
756+
config_cache = _read_config_cache()
757+
for name in names:
758+
try:
759+
detail = subprocess.run(
760+
["nemoclaw", "provider", "get", name],
761+
capture_output=True, text=True, timeout=30,
762+
)
763+
if detail.returncode == 0:
764+
parsed = self._parse_provider_detail(detail.stdout)
765+
if parsed:
766+
cached = config_cache.get(name, {})
767+
if cached:
768+
parsed["configValues"] = cached
769+
providers.append(parsed)
770+
except Exception:
771+
pass
772+
773+
return self._json_response(200, {"ok": True, "providers": providers})
774+
775+
def _read_json_body(self) -> dict | None:
776+
content_length = int(self.headers.get("Content-Length", 0))
777+
if content_length == 0:
778+
return None
779+
raw = self.rfile.read(content_length).decode("utf-8", errors="replace")
780+
try:
781+
return json.loads(raw)
782+
except json.JSONDecodeError:
783+
return None
784+
785+
def _handle_provider_create(self):
786+
data = self._read_json_body()
787+
if not data:
788+
return self._json_response(400, {"ok": False, "error": "invalid or empty JSON body"})
789+
790+
name = data.get("name", "").strip()
791+
ptype = data.get("type", "").strip()
792+
if not name or not ptype:
793+
return self._json_response(400, {"ok": False, "error": "name and type are required"})
794+
795+
cmd = ["nemoclaw", "provider", "create", "--name", name, "--type", ptype]
796+
creds = data.get("credentials", {})
797+
configs = data.get("config", {})
798+
if not creds:
799+
cmd += ["--credential", "PLACEHOLDER=unused"]
800+
for k, v in creds.items():
801+
cmd += ["--credential", f"{k}={v}"]
802+
for k, v in configs.items():
803+
cmd += ["--config", f"{k}={v}"]
804+
805+
try:
806+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
807+
if result.returncode != 0:
808+
err = (result.stderr or result.stdout or "create failed").strip()
809+
return self._json_response(400, {"ok": False, "error": err})
810+
if configs:
811+
_cache_provider_config(name, configs)
812+
return self._json_response(200, {"ok": True})
813+
except Exception as exc:
814+
return self._json_response(502, {"ok": False, "error": str(exc)})
815+
816+
def _handle_provider_update(self, name: str):
817+
data = self._read_json_body()
818+
if not data:
819+
return self._json_response(400, {"ok": False, "error": "invalid or empty JSON body"})
820+
821+
ptype = data.get("type", "").strip()
822+
if not ptype:
823+
return self._json_response(400, {"ok": False, "error": "type is required"})
824+
825+
cmd = ["nemoclaw", "provider", "update", name, "--type", ptype]
826+
for k, v in data.get("credentials", {}).items():
827+
cmd += ["--credential", f"{k}={v}"]
828+
configs = data.get("config", {})
829+
for k, v in configs.items():
830+
cmd += ["--config", f"{k}={v}"]
831+
832+
try:
833+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
834+
if result.returncode != 0:
835+
err = (result.stderr or result.stdout or "update failed").strip()
836+
return self._json_response(400, {"ok": False, "error": err})
837+
if configs:
838+
_cache_provider_config(name, configs)
839+
return self._json_response(200, {"ok": True})
840+
except Exception as exc:
841+
return self._json_response(502, {"ok": False, "error": str(exc)})
842+
843+
def _handle_provider_delete(self, name: str):
844+
try:
845+
result = subprocess.run(
846+
["nemoclaw", "provider", "delete", name],
847+
capture_output=True, text=True, timeout=30,
848+
)
849+
if result.returncode != 0:
850+
err = (result.stderr or result.stdout or "delete failed").strip()
851+
return self._json_response(400, {"ok": False, "error": err})
852+
_remove_cached_provider(name)
853+
return self._json_response(200, {"ok": True})
854+
except Exception as exc:
855+
return self._json_response(502, {"ok": False, "error": str(exc)})
856+
857+
# -- GET /api/cluster-inference ------------------------------------
858+
859+
@staticmethod
860+
def _parse_cluster_inference(stdout: str) -> dict | None:
861+
"""Parse ``nemoclaw cluster inference get/set`` output."""
862+
fields: dict[str, str] = {}
863+
for line in stdout.splitlines():
864+
stripped = _strip_ansi(line).strip()
865+
for key in ("Provider:", "Model:", "Version:"):
866+
if stripped.startswith(key):
867+
fields[key.rstrip(":")] = stripped[len(key):].strip()
868+
if "Provider" not in fields:
869+
return None
870+
version = 0
871+
try:
872+
version = int(fields.get("Version", "0"))
873+
except ValueError:
874+
pass
875+
return {
876+
"providerName": fields["Provider"],
877+
"modelId": fields.get("Model", ""),
878+
"version": version,
879+
}
880+
881+
def _handle_cluster_inference_get(self):
882+
try:
883+
result = subprocess.run(
884+
["nemoclaw", "cluster", "inference", "get"],
885+
capture_output=True, text=True, timeout=30,
886+
)
887+
if result.returncode != 0:
888+
stderr = (result.stderr or "").strip()
889+
if "not configured" in stderr.lower() or "not found" in stderr.lower():
890+
return self._json_response(200, {
891+
"ok": True,
892+
"providerName": None,
893+
"modelId": "",
894+
"version": 0,
895+
})
896+
err = stderr or (result.stdout or "get failed").strip()
897+
return self._json_response(400, {"ok": False, "error": err})
898+
parsed = self._parse_cluster_inference(result.stdout)
899+
if not parsed:
900+
return self._json_response(200, {
901+
"ok": True,
902+
"providerName": None,
903+
"modelId": "",
904+
"version": 0,
905+
})
906+
return self._json_response(200, {"ok": True, **parsed})
907+
except Exception as exc:
908+
return self._json_response(502, {"ok": False, "error": str(exc)})
909+
910+
# -- POST /api/cluster-inference -----------------------------------
911+
912+
def _handle_cluster_inference_set(self):
913+
body = self._read_json_body()
914+
if body is None:
915+
return self._json_response(400, {"ok": False, "error": "invalid JSON body"})
916+
provider_name = (body.get("providerName") or "").strip()
917+
model_id = (body.get("modelId") or "").strip()
918+
if not provider_name:
919+
return self._json_response(400, {"ok": False, "error": "providerName is required"})
920+
if not model_id:
921+
return self._json_response(400, {"ok": False, "error": "modelId is required"})
922+
try:
923+
result = subprocess.run(
924+
["nemoclaw", "cluster", "inference", "set",
925+
"--provider", provider_name,
926+
"--model", model_id],
927+
capture_output=True, text=True, timeout=30,
928+
)
929+
if result.returncode != 0:
930+
err = (result.stderr or result.stdout or "set failed").strip()
931+
return self._json_response(400, {"ok": False, "error": err})
932+
parsed = self._parse_cluster_inference(result.stdout)
933+
resp = {"ok": True}
934+
if parsed:
935+
resp.update(parsed)
936+
return self._json_response(200, resp)
937+
except Exception as exc:
938+
return self._json_response(502, {"ok": False, "error": str(exc)})
939+
661940
# -- GET /api/sandbox-status ----------------------------------------
662941

663942
def _handle_sandbox_status(self):
@@ -716,7 +995,20 @@ def log_message(self, fmt, *args):
716995
sys.stderr.write(f"[welcome-ui] {fmt % args}\n")
717996

718997

998+
def _bootstrap_config_cache() -> None:
999+
"""Seed the config cache for providers created before caching existed."""
1000+
if os.path.isfile(PROVIDER_CONFIG_CACHE):
1001+
return
1002+
_write_config_cache({
1003+
"nvidia-inference": {
1004+
"OPENAI_BASE_URL": "https://inference-api.nvidia.com/v1",
1005+
},
1006+
})
1007+
sys.stderr.write("[welcome-ui] Bootstrapped provider config cache\n")
1008+
1009+
7191010
if __name__ == "__main__":
1011+
_bootstrap_config_cache()
7201012
server = http.server.ThreadingHTTPServer(("", PORT), Handler)
7211013
print(f"NemoClaw Welcome UI → http://localhost:{PORT}")
7221014
server.serve_forever()

0 commit comments

Comments
 (0)