Skip to content

Commit 341b4ad

Browse files
authored
Rodin3D - add [Rodin3D Gen-2 generate] api-node (#9994)
* update Rodin api node * update rodin3d gen2 api node * fix images limited bug
1 parent b873051 commit 341b4ad

File tree

2 files changed

+118
-27
lines changed

2 files changed

+118
-27
lines changed

comfy_api_nodes/apis/rodin_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ class Rodin3DGenerateRequest(BaseModel):
99
seed: int = Field(..., description="seed_")
1010
tier: str = Field(..., description="Tier of generation.")
1111
material: str = Field(..., description="The material type.")
12-
quality: str = Field(..., description="The generation quality of the mesh.")
12+
quality_override: int = Field(..., description="The poly count of the mesh.")
1313
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
14+
TAPose: Optional[bool] = Field(None, description="")
1415

1516
class GenerateJobsData(BaseModel):
1617
uuids: List[str] = Field(..., description="str LIST")

comfy_api_nodes/nodes_rodin.py

Lines changed: 116 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,10 @@ def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
121121
else:
122122
return "Generating"
123123

124-
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
124+
async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs):
125125
if images is None:
126126
raise Exception("Rodin 3D generate requires at least 1 image.")
127-
if len(images) >= 5:
127+
if len(images) > 5:
128128
raise Exception("Rodin 3D generate requires up to 5 image.")
129129

130130
path = "/proxy/rodin/api/v2/rodin"
@@ -139,8 +139,9 @@ async def create_generate_task(self, images=None, seed=1, material="PBR", qualit
139139
seed=seed,
140140
tier=tier,
141141
material=material,
142-
quality=quality,
143-
mesh_mode=mesh_mode
142+
quality_override=quality_override,
143+
mesh_mode=mesh_mode,
144+
TAPose=TAPose,
144145
),
145146
files=[
146147
(
@@ -211,23 +212,36 @@ async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadRespon
211212
return await operation.execute()
212213

213214
def get_quality_mode(self, poly_count):
214-
if poly_count == "200K-Triangle":
215+
polycount = poly_count.split("-")
216+
poly = polycount[1]
217+
count = polycount[0]
218+
if poly == "Triangle":
215219
mesh_mode = "Raw"
216-
quality = "medium"
220+
elif poly == "Quad":
221+
mesh_mode = "Quad"
217222
else:
218223
mesh_mode = "Quad"
219-
if poly_count == "4K-Quad":
220-
quality = "extra-low"
221-
elif poly_count == "8K-Quad":
222-
quality = "low"
223-
elif poly_count == "18K-Quad":
224-
quality = "medium"
225-
elif poly_count == "50K-Quad":
226-
quality = "high"
227-
else:
228-
quality = "medium"
229-
230-
return mesh_mode, quality
224+
225+
if count == "4K":
226+
quality_override = 4000
227+
elif count == "8K":
228+
quality_override = 8000
229+
elif count == "18K":
230+
quality_override = 18000
231+
elif count == "50K":
232+
quality_override = 50000
233+
elif count == "2K":
234+
quality_override = 2000
235+
elif count == "20K":
236+
quality_override = 20000
237+
elif count == "150K":
238+
quality_override = 150000
239+
elif count == "500K":
240+
quality_override = 500000
241+
else:
242+
quality_override = 18000
243+
244+
return mesh_mode, quality_override
231245

232246
async def download_files(self, url_list):
233247
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
@@ -300,9 +314,9 @@ async def api_call(
300314
m_images = []
301315
for i in range(num_images):
302316
m_images.append(Images[i])
303-
mesh_mode, quality = self.get_quality_mode(Polygon_count)
317+
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
304318
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
305-
quality=quality, tier=tier, mesh_mode=mesh_mode,
319+
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
306320
**kwargs)
307321
await self.poll_for_task_status(subscription_key, **kwargs)
308322
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
@@ -346,9 +360,9 @@ async def api_call(
346360
m_images = []
347361
for i in range(num_images):
348362
m_images.append(Images[i])
349-
mesh_mode, quality = self.get_quality_mode(Polygon_count)
363+
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
350364
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
351-
quality=quality, tier=tier, mesh_mode=mesh_mode,
365+
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
352366
**kwargs)
353367
await self.poll_for_task_status(subscription_key, **kwargs)
354368
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
@@ -392,9 +406,9 @@ async def api_call(
392406
m_images = []
393407
for i in range(num_images):
394408
m_images.append(Images[i])
395-
mesh_mode, quality = self.get_quality_mode(Polygon_count)
409+
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
396410
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
397-
quality=quality, tier=tier, mesh_mode=mesh_mode,
411+
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
398412
**kwargs)
399413
await self.poll_for_task_status(subscription_key, **kwargs)
400414
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
@@ -446,24 +460,99 @@ async def api_call(
446460
for i in range(num_images):
447461
m_images.append(Images[i])
448462
material_type = "PBR"
449-
quality = "medium"
463+
quality_override = 18000
450464
mesh_mode = "Quad"
451465
task_uuid, subscription_key = await self.create_generate_task(
452-
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
466+
images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs
453467
)
454468
await self.poll_for_task_status(subscription_key, **kwargs)
455469
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
456470
model = await self.download_files(download_list)
457471

458472
return (model,)
459473

474+
class Rodin3D_Gen2(Rodin3DAPI):
475+
@classmethod
476+
def INPUT_TYPES(s):
477+
return {
478+
"required": {
479+
"Images":
480+
(
481+
IO.IMAGE,
482+
{
483+
"forceInput":True,
484+
}
485+
)
486+
},
487+
"optional": {
488+
"Seed": (
489+
IO.INT,
490+
{
491+
"default":0,
492+
"min":0,
493+
"max":65535,
494+
"display":"number"
495+
}
496+
),
497+
"Material_Type": (
498+
IO.COMBO,
499+
{
500+
"options": ["PBR", "Shaded"],
501+
"default": "PBR"
502+
}
503+
),
504+
"Polygon_count": (
505+
IO.COMBO,
506+
{
507+
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
508+
"default": "500K-Triangle"
509+
}
510+
),
511+
"TAPose": (
512+
IO.BOOLEAN,
513+
{
514+
"default": False,
515+
}
516+
)
517+
},
518+
"hidden": {
519+
"auth_token": "AUTH_TOKEN_COMFY_ORG",
520+
"comfy_api_key": "API_KEY_COMFY_ORG",
521+
},
522+
}
523+
524+
async def api_call(
525+
self,
526+
Images,
527+
Seed,
528+
Material_Type,
529+
Polygon_count,
530+
TAPose,
531+
**kwargs
532+
):
533+
tier = "Gen-2"
534+
num_images = Images.shape[0]
535+
m_images = []
536+
for i in range(num_images):
537+
m_images.append(Images[i])
538+
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
539+
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
540+
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose,
541+
**kwargs)
542+
await self.poll_for_task_status(subscription_key, **kwargs)
543+
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
544+
model = await self.download_files(download_list)
545+
546+
return (model,)
547+
460548
# A dictionary that contains all nodes you want to export with their names
461549
# NOTE: names should be globally unique
462550
NODE_CLASS_MAPPINGS = {
463551
"Rodin3D_Regular": Rodin3D_Regular,
464552
"Rodin3D_Detail": Rodin3D_Detail,
465553
"Rodin3D_Smooth": Rodin3D_Smooth,
466554
"Rodin3D_Sketch": Rodin3D_Sketch,
555+
"Rodin3D_Gen2": Rodin3D_Gen2,
467556
}
468557

469558
# A dictionary that contains the friendly/humanly readable titles for the nodes
@@ -472,4 +561,5 @@ async def api_call(
472561
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
473562
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
474563
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
564+
"Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate",
475565
}

0 commit comments

Comments
 (0)