Skip to content

Commit 789fbf5

Browse files
committed
Add audio output support; fix output handling when workflow execution fails
1 parent 0b3adbc commit 789fbf5

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

oneapi.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ async def execute_workflow(request):
8181
- images_by_var: Mapped image URLs by variable name {var_name: [string, ...], ...}, only present if there are image results
8282
- videos: List of video URLs [string, ...], only present if there are video results
8383
- videos_by_var: Mapped video URLs by variable name {var_name: [string, ...], ...}, only present if there are video results
84+
- audios: List of audio URLs [string, ...], only present if there are audio results
85+
- audios_by_var: Mapped audio URLs by variable name {var_name: [string, ...], ...}, only present if there are audio results
8486
- texts: List of text outputs [string, ...], only present if there are text results
8587
- texts_by_var: Mapped text outputs by variable name {var_name: [string, ...], ...}, only present if there are text results
8688
@@ -89,8 +91,8 @@ async def execute_workflow(request):
8991
- Output: Use "$output.name" in output node title to specify outputs
9092
- "$output.name" - Marks an output node with a custom output name (added to "images_by_var[name]", "videos_by_var[name]", "audios_by_var[name]" or "texts_by_var[name]")
9193
- If no explicit output marker is set, the node_id is used as the variable name
92-
- Any node that produces outputs (images, videos, texts) will be included in results
93-
- images/images_by_var/videos/videos_by_var/texts/texts_by_var fields are only included if there are corresponding results
94+
- Any node that produces outputs (images, videos, audios, texts) will be included in results
95+
- images/images_by_var/videos/videos_by_var/audios/audios_by_var/texts/texts_by_var fields are only included if there are corresponding results
9496
"""
9597
try:
9698
# Get request data
@@ -490,18 +492,20 @@ def _extend_flat_list_from_dict(media_dict):
490492

491493
def _split_media_by_suffix(node_output, base_url):
492494
"""
493-
Split all media entries in node_output into images/videos by file extension.
495+
Split all media entries in node_output into images/videos/audios by file extension.
494496
Args:
495497
node_output: Output dict for a node
496498
base_url: Base URL for constructing file URLs
497499
Returns:
498-
(images: list, videos: list) - lists of URLs
500+
(images: list, videos: list, audios: list) - lists of URLs
499501
"""
500502
image_exts = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tiff'}
501503
video_exts = {'.mp4', '.mov', '.avi', '.webm', '.gif'}
504+
audio_exts = {'.mp3', '.wav', '.flac', '.ogg', '.aac', '.m4a', '.wma', '.opus'}
502505
images = []
503506
videos = []
504-
for media_key in ("images", "gifs"):
507+
audios = []
508+
for media_key in ("images", "gifs", "audio"):
505509
for media_data in node_output.get(media_key, []):
506510
filename = media_data.get("filename")
507511
subfolder = media_data.get("subfolder", "")
@@ -516,7 +520,9 @@ def _split_media_by_suffix(node_output, base_url):
516520
images.append(url)
517521
elif ext in video_exts:
518522
videos.append(url)
519-
return images, videos
523+
elif ext in audio_exts:
524+
audios.append(url)
525+
return images, videos, audios
520526

521527
async def _wait_for_results(prompt_id, timeout=None, request=None, output_id_2_var=None):
522528
"""Wait for workflow execution results, get history using HTTP API"""
@@ -528,6 +534,8 @@ async def _wait_for_results(prompt_id, timeout=None, request=None, output_id_2_v
528534
"images_by_var": {},
529535
"videos": [],
530536
"videos_by_var": {},
537+
"audios": [],
538+
"audios_by_var": {},
531539
"texts": [],
532540
"texts_by_var": {}
533541
}
@@ -552,21 +560,41 @@ async def _wait_for_results(prompt_id, timeout=None, request=None, output_id_2_v
552560
if prompt_id not in history_data:
553561
await asyncio.sleep(1.0)
554562
continue
563+
555564
prompt_history = history_data[prompt_id]
565+
status = prompt_history.get("status")
566+
if status and status.get("status_str") == "error":
567+
result["status"] = "error"
568+
messages = status.get("messages")
569+
if messages:
570+
errors = [
571+
body.get("exception_message")
572+
for type, body in messages
573+
if type == "execution_error"
574+
]
575+
error_message = "\n".join(errors)
576+
else:
577+
error_message = "Unknown error"
578+
result["error"] = error_message
579+
return result
580+
556581
if "outputs" in prompt_history:
557582
result["outputs"] = prompt_history["outputs"]
558583
result["status"] = "completed"
559584

560-
# Collect all image and video outputs by file extension
585+
# Collect image, video and audio outputs (classified by file extension) plus text outputs, keyed by node id
561586
output_id_2_images = {}
562587
output_id_2_videos = {}
588+
output_id_2_audios = {}
563589
output_id_2_texts = {}
564590
for node_id, node_output in prompt_history["outputs"].items():
565-
images, videos = _split_media_by_suffix(node_output, base_url)
591+
images, videos, audios = _split_media_by_suffix(node_output, base_url)
566592
if images:
567593
output_id_2_images[node_id] = images
568594
if videos:
569595
output_id_2_videos[node_id] = videos
596+
if audios:
597+
output_id_2_audios[node_id] = audios
570598
# Collect text outputs
571599
if "text" in node_output:
572600
# Handle text field as string or list
@@ -586,12 +614,16 @@ async def _wait_for_results(prompt_id, timeout=None, request=None, output_id_2_v
586614
result["videos_by_var"] = _map_outputs_by_var(output_id_2_var, output_id_2_videos)
587615
result["videos"] = _extend_flat_list_from_dict(result["videos_by_var"])
588616

617+
if output_id_2_audios:
618+
result["audios_by_var"] = _map_outputs_by_var(output_id_2_var, output_id_2_audios)
619+
result["audios"] = _extend_flat_list_from_dict(result["audios_by_var"])
620+
589621
# Handle texts/texts_by_var
590622
if output_id_2_texts:
591623
result["texts_by_var"] = _map_outputs_by_var(output_id_2_var, output_id_2_texts)
592624
result["texts"] = _extend_flat_list_from_dict(result["texts_by_var"])
593625

594-
# Remove empty fields for images/videos/texts
626+
# Remove empty fields for images/videos/audios/texts
595627
if not result.get("images"):
596628
result.pop("images", None)
597629
if not result.get("images_by_var"):
@@ -600,6 +632,10 @@ async def _wait_for_results(prompt_id, timeout=None, request=None, output_id_2_v
600632
result.pop("videos", None)
601633
if not result.get("videos_by_var"):
602634
result.pop("videos_by_var", None)
635+
if not result.get("audios"):
636+
result.pop("audios", None)
637+
if not result.get("audios_by_var"):
638+
result.pop("audios_by_var", None)
603639
if not result.get("texts"):
604640
result.pop("texts", None)
605641
if not result.get("texts_by_var"):

0 commit comments

Comments
 (0)