@@ -121,10 +121,10 @@ def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
121
121
else :
122
122
return "Generating"
123
123
124
- async def create_generate_task (self , images = None , seed = 1 , material = "PBR" , quality = "medium" , tier = "Regular" , mesh_mode = "Quad" , ** kwargs ):
124
+ async def create_generate_task (self , images = None , seed = 1 , material = "PBR" , quality_override = 18000 , tier = "Regular" , mesh_mode = "Quad" , TAPose = False , ** kwargs ):
125
125
if images is None :
126
126
raise Exception ("Rodin 3D generate requires at least 1 image." )
127
- if len (images ) >= 5 :
127
+ if len (images ) > 5 :
128
128
raise Exception ("Rodin 3D generate requires up to 5 image." )
129
129
130
130
path = "/proxy/rodin/api/v2/rodin"
@@ -139,8 +139,9 @@ async def create_generate_task(self, images=None, seed=1, material="PBR", qualit
139
139
seed = seed ,
140
140
tier = tier ,
141
141
material = material ,
142
- quality = quality ,
143
- mesh_mode = mesh_mode
142
+ quality_override = quality_override ,
143
+ mesh_mode = mesh_mode ,
144
+ TAPose = TAPose ,
144
145
),
145
146
files = [
146
147
(
@@ -211,23 +212,36 @@ async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadRespon
211
212
return await operation .execute ()
212
213
213
214
def get_quality_mode (self , poly_count ):
214
- if poly_count == "200K-Triangle" :
215
+ polycount = poly_count .split ("-" )
216
+ poly = polycount [1 ]
217
+ count = polycount [0 ]
218
+ if poly == "Triangle" :
215
219
mesh_mode = "Raw"
216
- quality = "medium"
220
+ elif poly == "Quad" :
221
+ mesh_mode = "Quad"
217
222
else :
218
223
mesh_mode = "Quad"
219
- if poly_count == "4K-Quad" :
220
- quality = "extra-low"
221
- elif poly_count == "8K-Quad" :
222
- quality = "low"
223
- elif poly_count == "18K-Quad" :
224
- quality = "medium"
225
- elif poly_count == "50K-Quad" :
226
- quality = "high"
227
- else :
228
- quality = "medium"
229
-
230
- return mesh_mode , quality
224
+
225
+ if count == "4K" :
226
+ quality_override = 4000
227
+ elif count == "8K" :
228
+ quality_override = 8000
229
+ elif count == "18K" :
230
+ quality_override = 18000
231
+ elif count == "50K" :
232
+ quality_override = 50000
233
+ elif count == "2K" :
234
+ quality_override = 2000
235
+ elif count == "20K" :
236
+ quality_override = 20000
237
+ elif count == "150K" :
238
+ quality_override = 150000
239
+ elif count == "500K" :
240
+ quality_override = 500000
241
+ else :
242
+ quality_override = 18000
243
+
244
+ return mesh_mode , quality_override
231
245
232
246
async def download_files (self , url_list ):
233
247
save_path = os .path .join (comfy_paths .get_output_directory (), "Rodin3D" , datetime .datetime .now ().strftime ("%Y-%m-%d_%H-%M-%S" ))
@@ -300,9 +314,9 @@ async def api_call(
300
314
m_images = []
301
315
for i in range (num_images ):
302
316
m_images .append (Images [i ])
303
- mesh_mode , quality = self .get_quality_mode (Polygon_count )
317
+ mesh_mode , quality_override = self .get_quality_mode (Polygon_count )
304
318
task_uuid , subscription_key = await self .create_generate_task (images = m_images , seed = Seed , material = Material_Type ,
305
- quality = quality , tier = tier , mesh_mode = mesh_mode ,
319
+ quality_override = quality_override , tier = tier , mesh_mode = mesh_mode ,
306
320
** kwargs )
307
321
await self .poll_for_task_status (subscription_key , ** kwargs )
308
322
download_list = await self .get_rodin_download_list (task_uuid , ** kwargs )
@@ -346,9 +360,9 @@ async def api_call(
346
360
m_images = []
347
361
for i in range (num_images ):
348
362
m_images .append (Images [i ])
349
- mesh_mode , quality = self .get_quality_mode (Polygon_count )
363
+ mesh_mode , quality_override = self .get_quality_mode (Polygon_count )
350
364
task_uuid , subscription_key = await self .create_generate_task (images = m_images , seed = Seed , material = Material_Type ,
351
- quality = quality , tier = tier , mesh_mode = mesh_mode ,
365
+ quality_override = quality_override , tier = tier , mesh_mode = mesh_mode ,
352
366
** kwargs )
353
367
await self .poll_for_task_status (subscription_key , ** kwargs )
354
368
download_list = await self .get_rodin_download_list (task_uuid , ** kwargs )
@@ -392,9 +406,9 @@ async def api_call(
392
406
m_images = []
393
407
for i in range (num_images ):
394
408
m_images .append (Images [i ])
395
- mesh_mode , quality = self .get_quality_mode (Polygon_count )
409
+ mesh_mode , quality_override = self .get_quality_mode (Polygon_count )
396
410
task_uuid , subscription_key = await self .create_generate_task (images = m_images , seed = Seed , material = Material_Type ,
397
- quality = quality , tier = tier , mesh_mode = mesh_mode ,
411
+ quality_override = quality_override , tier = tier , mesh_mode = mesh_mode ,
398
412
** kwargs )
399
413
await self .poll_for_task_status (subscription_key , ** kwargs )
400
414
download_list = await self .get_rodin_download_list (task_uuid , ** kwargs )
@@ -446,24 +460,99 @@ async def api_call(
446
460
for i in range (num_images ):
447
461
m_images .append (Images [i ])
448
462
material_type = "PBR"
449
- quality = "medium"
463
+ quality_override = 18000
450
464
mesh_mode = "Quad"
451
465
task_uuid , subscription_key = await self .create_generate_task (
452
- images = m_images , seed = Seed , material = material_type , quality = quality , tier = tier , mesh_mode = mesh_mode , ** kwargs
466
+ images = m_images , seed = Seed , material = material_type , quality_override = quality_override , tier = tier , mesh_mode = mesh_mode , ** kwargs
453
467
)
454
468
await self .poll_for_task_status (subscription_key , ** kwargs )
455
469
download_list = await self .get_rodin_download_list (task_uuid , ** kwargs )
456
470
model = await self .download_files (download_list )
457
471
458
472
return (model ,)
459
473
474
+ class Rodin3D_Gen2 (Rodin3DAPI ):
475
+ @classmethod
476
+ def INPUT_TYPES (s ):
477
+ return {
478
+ "required" : {
479
+ "Images" :
480
+ (
481
+ IO .IMAGE ,
482
+ {
483
+ "forceInput" :True ,
484
+ }
485
+ )
486
+ },
487
+ "optional" : {
488
+ "Seed" : (
489
+ IO .INT ,
490
+ {
491
+ "default" :0 ,
492
+ "min" :0 ,
493
+ "max" :65535 ,
494
+ "display" :"number"
495
+ }
496
+ ),
497
+ "Material_Type" : (
498
+ IO .COMBO ,
499
+ {
500
+ "options" : ["PBR" , "Shaded" ],
501
+ "default" : "PBR"
502
+ }
503
+ ),
504
+ "Polygon_count" : (
505
+ IO .COMBO ,
506
+ {
507
+ "options" : ["4K-Quad" , "8K-Quad" , "18K-Quad" , "50K-Quad" , "2K-Triangle" , "20K-Triangle" , "150K-Triangle" , "500K-Triangle" ],
508
+ "default" : "500K-Triangle"
509
+ }
510
+ ),
511
+ "TAPose" : (
512
+ IO .BOOLEAN ,
513
+ {
514
+ "default" : False ,
515
+ }
516
+ )
517
+ },
518
+ "hidden" : {
519
+ "auth_token" : "AUTH_TOKEN_COMFY_ORG" ,
520
+ "comfy_api_key" : "API_KEY_COMFY_ORG" ,
521
+ },
522
+ }
523
+
524
+ async def api_call (
525
+ self ,
526
+ Images ,
527
+ Seed ,
528
+ Material_Type ,
529
+ Polygon_count ,
530
+ TAPose ,
531
+ ** kwargs
532
+ ):
533
+ tier = "Gen-2"
534
+ num_images = Images .shape [0 ]
535
+ m_images = []
536
+ for i in range (num_images ):
537
+ m_images .append (Images [i ])
538
+ mesh_mode , quality_override = self .get_quality_mode (Polygon_count )
539
+ task_uuid , subscription_key = await self .create_generate_task (images = m_images , seed = Seed , material = Material_Type ,
540
+ quality_override = quality_override , tier = tier , mesh_mode = mesh_mode , TAPose = TAPose ,
541
+ ** kwargs )
542
+ await self .poll_for_task_status (subscription_key , ** kwargs )
543
+ download_list = await self .get_rodin_download_list (task_uuid , ** kwargs )
544
+ model = await self .download_files (download_list )
545
+
546
+ return (model ,)
547
+
460
548
# A dictionary that contains all nodes you want to export with their names
461
549
# NOTE: names should be globally unique
462
550
NODE_CLASS_MAPPINGS = {
463
551
"Rodin3D_Regular" : Rodin3D_Regular ,
464
552
"Rodin3D_Detail" : Rodin3D_Detail ,
465
553
"Rodin3D_Smooth" : Rodin3D_Smooth ,
466
554
"Rodin3D_Sketch" : Rodin3D_Sketch ,
555
+ "Rodin3D_Gen2" : Rodin3D_Gen2 ,
467
556
}
468
557
469
558
# A dictionary that contains the friendly/humanly readable titles for the nodes
@@ -472,4 +561,5 @@ async def api_call(
472
561
"Rodin3D_Detail" : "Rodin 3D Generate - Detail Generate" ,
473
562
"Rodin3D_Smooth" : "Rodin 3D Generate - Smooth Generate" ,
474
563
"Rodin3D_Sketch" : "Rodin 3D Generate - Sketch Generate" ,
564
+ "Rodin3D_Gen2" : "Rodin 3D Generate - Gen-2 Generate" ,
475
565
}
0 commit comments