Skip to content

Commit 111909f

Browse files
kudroma404Deniskorerhohndorf
authored
Feat/gateway (#7)
* feat: gateway blocking pipeline * feat: spz auto update * Implement SPZLoader * feat: unblocking operator * fix: correct name for parallel generations * refactor: remove unnecessary print * fix: version fix * fix import issue in dependencies.py --------- Co-authored-by: Denis Avvakumov <denisavvakumov@gmail.com> Co-authored-by: Ruben Hohndorf <ruben.hohndorf@gmail.com>
1 parent a8b596e commit 111909f

File tree

15 files changed

+591
-98
lines changed

15 files changed

+591
-98
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
__pycache__/
33
*.py[cod]
44
*$py.class
5+
.env
56

67
# C extensions
78
*.so
@@ -131,4 +132,5 @@ dmypy.json
131132
# Vim stuff
132133
*.sw*
133134

134-
*.blend1
135+
*.blend1
136+
spz_version.txt

fourofour_3d_gen/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
def register():
88
preferences.register()
99
modules.append(preferences)
10+
1011
if dependencies.installed():
1112
from . import props, ui, ops
1213

@@ -17,6 +18,17 @@ def register():
1718
modules.append(ui)
1819
modules.append(props)
1920

21+
try:
22+
from .spz_loader import init_spz
23+
from pathlib import Path
24+
25+
pkg_dir = Path(__file__).resolve().parent
26+
print(f"Initializing SPZ with library path: {pkg_dir}")
27+
init_spz(str(pkg_dir))
28+
except Exception as e:
29+
# Report failure to load native SPZ library to console only
30+
print(f"SPZ initialization failed: {e}")
31+
2032

2133
def unregister():
2234
for m in modules:

fourofour_3d_gen/blender_manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
schema_version = "1.0.0"
22

33
id = "fourofour_3d_gen"
4-
version = "0.9.0"
4+
version = "0.10.0"
55
name = "404 3D Generator"
66
tagline = "404 3D Generator Extension"
77
maintainer = "Ruben Hohndorf <ruben@atlas.design>"

fourofour_3d_gen/client.py

Lines changed: 73 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,80 @@
22
import base64
33
import json
44
import tempfile
5-
import websocket
6-
5+
import time
6+
from .spz_loader import get_spz
77

88
from .protocol import Auth, PromptData, TaskStatus, TaskUpdate
9+
from .gateway.gateway_api import GatewayApi
10+
from .gateway.gateway_task import GatewayTask, GatewayTaskStatusResponse, GatewayTaskStatus
11+
12+
13+
_GATEWAY_STATUS_CHECK_INTERVAL_SEC: int = 5
14+
_GATEWAY_TASK_TIMEOUT_SEC: int = 10 * 60 # 10 minutes
15+
16+
17+
class GatewayErrorBase(Exception):
18+
pass
19+
20+
21+
class GatewayTimeoutError(GatewayErrorBase):
22+
pass
23+
24+
25+
class GatewayFailureError(GatewayErrorBase):
26+
pass
27+
28+
29+
class Client:
30+
31+
def __init__(self) -> None:
32+
self._task: GatewayTask | None = None
33+
self._task_start_time: float | None = None
34+
gateway_url = bpy.context.preferences.addons[__package__].preferences.url
35+
gateway_api_key = bpy.context.preferences.addons[__package__].preferences.token
36+
self._gateway_api: GatewayApi = GatewayApi(gateway_url=gateway_url, gateway_api_key=gateway_api_key)
37+
38+
@property
39+
def task_id(self) -> str | None:
40+
return self._task.id if self._task else None
41+
42+
@property
43+
def prompt(self) -> str | None:
44+
return self._task.prompt if self._task else None
45+
46+
def request_model(self, prompt: str) -> None:
47+
task = self._gateway_api.add_task(text_prompt=prompt)
48+
print(f"Task added: {task.id}")
49+
self._task = task
50+
self._task_start_time = time.time()
51+
52+
def get_result(self) -> str | None:
53+
if self._task is None or self._task_start_time is None:
54+
return None
55+
56+
# Check for timeout
57+
if time.time() - self._task_start_time > _GATEWAY_TASK_TIMEOUT_SEC:
58+
raise GatewayTimeoutError("Gateway timeout error")
59+
60+
task_status_response = self._gateway_api.get_status(task=self._task)
61+
task_status = task_status_response.status
62+
63+
if task_status not in [GatewayTaskStatus.SUCCESS, GatewayTaskStatus.FAILURE]:
64+
return None
65+
66+
if task_status == GatewayTaskStatus.FAILURE:
67+
raise GatewayFailureError(f"Gateway failure error")
68+
69+
if task_status == GatewayTaskStatus.SUCCESS:
70+
spz_data = self._gateway_api.get_result(task=self._task)
71+
print(f"Received result for task: {self._task.id}")
972

73+
loader = get_spz()
74+
with tempfile.NamedTemporaryFile(delete=False, suffix=".ply") as temp_file:
75+
ply_data = loader.decompress(spz_data, include_normals=False)
76+
temp_file.write(ply_data)
77+
filepath = temp_file.name
78+
print(f"Saved result to: {filepath}")
79+
return filepath
1080

11-
def request_model(prompt: str) -> tuple[None, None] | tuple[str, str]:
12-
url = bpy.context.preferences.addons[__package__].preferences.url
13-
api_key = bpy.context.preferences.addons[__package__].preferences.token
14-
filepath = None
15-
winner_hotkey = None
16-
17-
def on_message(ws, message):
18-
nonlocal filepath, winner_hotkey
19-
update = TaskUpdate(**json.loads(message))
20-
if update.status == TaskStatus.STARTED:
21-
print("Task started")
22-
elif update.status == TaskStatus.FIRST_RESULTS:
23-
score = update.results.score if update.results else None
24-
assets = update.results.assets or "" if update.results else ""
25-
print(f"First results. Score: {score}. Size: {len(assets)}")
26-
elif update.status == TaskStatus.BEST_RESULTS:
27-
score = update.results.score if update.results else None
28-
assets = update.results.assets or "" if update.results else ""
29-
print(f"Best results. Score: {score}. Size: {len(assets)}")
30-
print(f"Stats: {update.statistics}")
31-
32-
if assets:
33-
with tempfile.NamedTemporaryFile(delete=False, suffix=".ply") as temp_file:
34-
temp_file.write(base64.b64decode(assets.encode("utf-8")))
35-
filepath = temp_file.name
36-
ws.close()
37-
38-
def on_error(ws, error):
39-
print(f"WebSocket connection error: {error}")
40-
41-
def on_close(ws, close_status_code, close_msg):
42-
print(f"WebSocket connection closed: {close_status_code} {close_msg}")
43-
44-
def on_open(ws):
45-
auth_data = Auth(api_key=api_key).dict()
46-
prompt_data = PromptData(prompt=prompt, send_first_results=True).dict()
47-
ws.send(json.dumps(auth_data))
48-
ws.send(json.dumps(prompt_data))
49-
50-
ws = websocket.WebSocketApp(url, on_message=on_message, on_error=on_error, on_close=on_close)
51-
ws.on_open = on_open
52-
ws.run_forever()
53-
54-
return (filepath, winner_hotkey)
81+
return None

fourofour_3d_gen/dependencies.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ def installed() -> bool:
88
try:
99
import numpy
1010
import pydantic
11-
import websocket
1211
import mixpanel
1312

1413
return True
@@ -44,15 +43,20 @@ def install_dependencies_from_requirements():
4443
env=env_var,
4544
)
4645

46+
def update_spz():
47+
from .spz_updater import SPZUpdater
48+
49+
if SPZUpdater.need_update():
50+
SPZUpdater.update()
51+
4752

4853
def install() -> None:
4954
install_pip()
50-
5155
install_dependencies_from_requirements()
5256

53-
5457
if __name__ == "__main__":
5558
if installed():
5659
print("dependencies installed")
5760
else:
5861
install()
62+
update_spz()
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import json
2+
import requests
3+
4+
from typing import Any, cast
5+
from urllib.parse import urlencode
6+
7+
from .gateway_routes import GatewayRoutes
8+
from .gateway_task import GatewayTask, GatewayTaskStatusResponse
9+
10+
11+
class GatewayErrorBase(Exception):
12+
pass
13+
14+
15+
class GatewayAddTaskError(GatewayErrorBase):
16+
pass
17+
18+
19+
class GatewayGetStatusError(GatewayErrorBase):
20+
pass
21+
22+
23+
class GatewayGetResultError(GatewayErrorBase):
24+
pass
25+
26+
27+
class GatewayNoAttachmentError(GatewayErrorBase):
28+
pass
29+
30+
31+
class GatewayApi:
32+
"""API client for interacting with gateway."""
33+
34+
def __init__(self, *, gateway_url: str, gateway_api_key: str) -> None:
35+
self._http2_client = requests.Session()
36+
"""HTTP3 client for interacting with gateway."""
37+
self._gateway_url = gateway_url
38+
"""URL of the gateway."""
39+
self._gateway_api_key = gateway_api_key
40+
"""API key of the gateway that provides no rate limit for discord bot."""
41+
42+
def add_task(self, *, text_prompt: str) -> GatewayTask:
43+
"""Adds a task to the gateway."""
44+
try:
45+
url = self._construct_url(host=self._gateway_url, route=GatewayRoutes.ADD_TASK)
46+
payload = {"prompt": text_prompt}
47+
headers = {"x-api-key": self._gateway_api_key}
48+
response = self._http2_client.post(url=url, json=payload, headers=headers)
49+
print(response.text)
50+
response.raise_for_status()
51+
data = json.loads(response.text)
52+
return GatewayTask.model_validate(data)
53+
except Exception as e:
54+
raise GatewayAddTaskError(f"Gateway: error to add task: {e}") from e
55+
56+
def get_status(self, *, task: GatewayTask) -> GatewayTaskStatusResponse:
57+
"""Gets the status of a task."""
58+
try:
59+
url = self._construct_url(host=self._gateway_url, route=GatewayRoutes.GET_STATUS, id=task.id)
60+
headers = {"x-api-key": self._gateway_api_key}
61+
response = self._http2_client.get(url=url, headers=headers)
62+
response.raise_for_status()
63+
data = json.loads(response.text)
64+
reason = data.get("reason", None)
65+
return GatewayTaskStatusResponse(status=data["status"], reason=reason)
66+
except Exception as e:
67+
raise GatewayGetStatusError(f"Gateway: error to get status: {e}") from e
68+
69+
def get_result(self, *, task: GatewayTask) -> bytes:
70+
"""Gets generated 3D asset in spz format."""
71+
try:
72+
url = self._construct_url(host=self._gateway_url, route=GatewayRoutes.GET_RESULT, id=task.id)
73+
headers = {"x-api-key": self._gateway_api_key}
74+
response = self._http2_client.get(url=url, headers=headers)
75+
response.raise_for_status()
76+
if response.headers.get('content-disposition', '').startswith('attachment'):
77+
return cast(bytes, response.content)
78+
raise GatewayNoAttachmentError()
79+
except Exception as e:
80+
raise GatewayGetResultError(f"Gateway: error to get result: {e}") from e
81+
82+
def _construct_url(self, *, host: str, route: GatewayRoutes, **kwargs: Any) -> str:
83+
return f"{host}{route.value}?{urlencode(kwargs)}"
84+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from enum import Enum
2+
3+
4+
class GatewayRoutes(Enum):
5+
"""Routes for gateway nodes"""
6+
7+
ADD_TASK = "/add_task"
8+
"""Send text prompt to gateway to generate 3D asset."""
9+
GET_STATUS = "/get_status"
10+
"""Get status of the task."""
11+
GET_RESULT = "/get_result"
12+
"""Get result of the generation in spz format."""
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from enum import Enum
2+
3+
from pydantic import BaseModel
4+
5+
6+
class GatewayTaskStatus(Enum):
7+
"""Status of the task in gateway"""
8+
9+
NO_RESULT = "NoResult"
10+
FAILURE = "Failure"
11+
PARTIAL_RESULT = "PartialResult"
12+
SUCCESS = "Success"
13+
14+
15+
class GatewayTaskStatusResponse(BaseModel):
16+
"""Response from the gateway for the task status"""
17+
18+
status: GatewayTaskStatus
19+
"""Status of the task in gateway"""
20+
reason: str | None = None
21+
"""Reason for the status. If status is not SUCCESS, reason is not None."""
22+
23+
24+
class GatewayTask(BaseModel):
25+
"""Task for the gateway"""
26+
27+
id: str
28+
"""Unique id of the task in gateway"""
29+
prompt: str
30+
"""Text prompt to generate the asset 3D asset"""
31+
# TODO: in what format should we store in bunny net.
32+
result: bytes | None = None
33+
"""Result of the task in gateway in spz format"""
34+
task_status: GatewayTaskStatus = GatewayTaskStatus.NO_RESULT
35+
"""Status of the task in gateway"""

fourofour_3d_gen/gaussian_splatting.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
RECOMMENDED_MAX_GAUSSIANS = 200_000
1010

1111

12-
def import_gs(filepath: str, name: str, winner_hotkey: str = ""):
12+
def import_gs(filepath: str, name: str):
1313

1414
if "GaussianSplatting" not in bpy.data.node_groups:
1515
script_file = os.path.realpath(__file__)
@@ -131,8 +131,6 @@ def import_gs(filepath: str, name: str, winner_hotkey: str = ""):
131131
obj.rotation_euler = (-np.pi / 2, 0, 0)
132132
obj.rotation_euler[0] = 1.5708
133133

134-
obj["Bittensor Miner"] = winner_hotkey
135-
136134
print("Mesh attributes added in", time.time() - start_time, "seconds")
137135

138136
setup_nodes(obj)

0 commit comments

Comments
 (0)