diff --git a/docker-compose.yaml b/docker-compose.yaml
index 84a9aae1a..f32270dcb 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -214,6 +214,8 @@ services:
   audio:
     image: ${AUDIO_IMAGE:-nvcr.io/nim/nvidia/parakeet-1-1b-ctc-en-us}:${AUDIO_TAG:-1.4.0}
     shm_size: 2gb
+    ports:
+      - "${AUDIO_GRPC_PORT:-50051}:50051"
     expose:
       - "50051" # grpc
       - "9000" # http
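The new ports: mapping publishes the Parakeet gRPC port to the host, so the retriever pipeline can run outside the compose network. A quick reachability check before ingesting (a minimal sketch; assumes the compose stack is up and AUDIO_GRPC_PORT kept its 50051 default):

import socket

# TCP-level check only: it confirms the NIM container is listening,
# not that the ASR model has finished loading.
with socket.create_connection(("localhost", 50051), timeout=5):
    print("Parakeet gRPC port is reachable from the host")
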
diff --git a/retriever/audio_stage_config.yaml b/retriever/audio_stage_config.yaml
new file mode 100644
index 000000000..28b0817e1
--- /dev/null
+++ b/retriever/audio_stage_config.yaml
@@ -0,0 +1,49 @@
+# Example config for:
+#   - `retriever audio stage run --config <config.yaml> --input <input>`
+#
+# This YAML is parsed into `nv_ingest_api.internal.schemas.extract.extract_audio_schema.AudioExtractorSchema`
+# via `retriever.audio.config.load_audio_extractor_schema_from_dict`.
+#
+# IMPORTANT:
+#   `audio_extraction_config.audio_endpoints` must provide at least one endpoint
+#   (gRPC or HTTP). Both cannot be null/empty. HTTP is not supported for audio;
+#   use gRPC.
+
+# Optional worker settings
+max_queue_size: 1
+n_workers: 2
+raise_on_failure: false

+# Audio extraction configuration (Riva / Parakeet NIM).
+audio_extraction_config:
+  # Optional auth token for secured services (NIM / NVCF)
+  auth_token: null
+
+  # Tuple/list in the form: [grpc, http]
+  # Riva/Parakeet ASR endpoint. Only gRPC is supported for audio.
+  #
+  # For the provided docker-compose.yaml the reachable endpoints are:
+  #   - gRPC: audio:50051 (inside the docker network)
+  #   - gRPC: localhost:50051 (from the host, via the published port)
+  #
+  # For NVCF hosted endpoints:
+  #   - gRPC: grpc.nvcf.nvidia.com:443
+  #   audio_endpoints: ["audio:50051", null]
+  audio_endpoints: ["localhost:50051", null]
+
+
+  # Optional; if omitted it is inferred from which endpoint is present.
+  # Only "grpc" is supported for audio.
+  audio_infer_protocol: grpc
+
+  # Optional NVCF function ID (required when using grpc.nvcf.nvidia.com)
+  function_id: null
+
+  # SSL settings (auto-detected for NVCF endpoints)
+  use_ssl: null
+  ssl_cert: null
+
+  # If true, each speech segment (sentence) becomes a separate row with
+  # start_time / end_time metadata. If false (default), one row per file
+  # containing the full transcript.
+  segment_audio: false
diff --git a/retriever/src/retriever/audio/__init__.py b/retriever/src/retriever/audio/__init__.py
new file mode 100644
index 000000000..d073ed325
--- /dev/null
+++ b/retriever/src/retriever/audio/__init__.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from .transcribe import (
+    audio_bytes_to_transcript_df,
+    audio_file_to_transcript_df,
+)
+
+__all__ = [
+    "audio_bytes_to_transcript_df",
+    "audio_file_to_transcript_df",
+]
diff --git a/retriever/src/retriever/audio/ray_data.py b/retriever/src/retriever/audio/ray_data.py
new file mode 100644
index 000000000..36c472abf
--- /dev/null
+++ b/retriever/src/retriever/audio/ray_data.py
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Ray Data adapter for audio: AudioTranscribeActor turns bytes+path batches into transcript rows.
+"""
+
+from __future__ import annotations
+
+from typing import List
+
+import pandas as pd
+
+from retriever.params import AudioExtractParams
+
+from .transcribe import audio_bytes_to_transcript_df
+
+
+class AudioTranscribeActor:
+    """Ray Data map_batches callable: DataFrame with bytes, path -> DataFrame of transcript chunks.
+
+    Each output row has: text, content, path, metadata
+    (same shape as audio_file_to_transcript_df).
+    """
+
+    def __init__(self, params: AudioExtractParams | None = None) -> None:
+        self._params = params or AudioExtractParams()
+
+    def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame:
+        if not isinstance(batch_df, pd.DataFrame) or batch_df.empty:
+            return pd.DataFrame(columns=["text", "content", "path", "page_number", "metadata"])
+
+        params = self._params
+        out_dfs: List[pd.DataFrame] = []
+        for _, row in batch_df.iterrows():
+            raw = row.get("bytes")
+            path = row.get("path")
+            if raw is None or path is None:
+                continue
+            path_str = str(path)
+            try:
+                chunk_df = audio_bytes_to_transcript_df(
+                    raw,
+                    path_str,
+                    grpc_endpoint=params.grpc_endpoint,
+                    auth_token=params.auth_token,
+                    function_id=params.function_id,
+                    use_ssl=params.use_ssl,
+                    ssl_cert=params.ssl_cert,
+                    segment_audio=params.segment_audio,
+                    max_tokens=params.max_tokens,
+                    overlap_tokens=params.overlap_tokens,
+                    tokenizer_model_id=params.tokenizer_model_id,
+                    tokenizer_cache_dir=params.tokenizer_cache_dir,
+                )
+                if not chunk_df.empty:
+                    out_dfs.append(chunk_df)
+            except Exception:
+                continue
+        if not out_dfs:
+            return pd.DataFrame(columns=["text", "content", "path", "page_number", "metadata"])
+        return pd.concat(out_dfs, ignore_index=True)
diff --git a/retriever/src/retriever/audio/transcribe.py b/retriever/src/retriever/audio/transcribe.py
new file mode 100644
index 000000000..ecd89dd54
--- /dev/null
+++ b/retriever/src/retriever/audio/transcribe.py
@@ -0,0 +1,298 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Audio transcription via Riva/Parakeet NIM for the retriever pipeline.
+
+Produces chunk DataFrames compatible with embed_text_from_primitives_df
+and the LanceDB row builder (text, path, page_number, metadata).
+"""
+
+from __future__ import annotations
+
+import base64
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import pandas as pd
+
+from nv_ingest_api.internal.primitives.nim.model_interface.parakeet import create_audio_inference_client
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_TOKENIZER_MODEL_ID = "nvidia/llama-3.2-nv-embedqa-1b-v2"
+
+
+def _get_tokenizer(model_id: str, cache_dir: Optional[str] = None):  # noqa: ANN201
+    """Lazy-load HuggingFace tokenizer."""
+    from transformers import AutoTokenizer
+
+    return AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
+
+
+def _split_text_by_tokens(
+    text: str,
+    *,
+    tokenizer: Any,
+    max_tokens: int,
+    overlap_tokens: int = 0,
+) -> List[str]:
+    """Split text into chunks by token count with optional overlap."""
+    if not text or not text.strip():
+        return []
+    if max_tokens <= 0:
+        raise ValueError("max_tokens must be positive")
+
+    enc = tokenizer.encode(text, add_special_tokens=False)
+    if not enc:
+        return []
+
+    step = max(1, max_tokens - overlap_tokens)
+    chunks: List[str] = []
+    start = 0
+    while start < len(enc):
+        end = min(start + max_tokens, len(enc))
+        chunk_ids = enc[start:end]
+        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True)
+        if chunk_text.strip():
+            chunks.append(chunk_text)
+        start += step
+        if end >= len(enc):
+            break
+
+    return chunks if chunks else [text]
+
+
+def _transcribe_audio_bytes(
+    audio_bytes: bytes,
+    *,
+    grpc_endpoint: str,
+    auth_token: Optional[str] = None,
+    function_id: Optional[str] = None,
+    use_ssl: Optional[bool] = None,
+    ssl_cert: Optional[str] = None,
+) -> tuple:
+    """Create a ParakeetClient and transcribe raw audio bytes.
+
+    Returns (segments, transcript) where segments is a list of
+    ``{"start": float, "end": float, "text": str}`` dicts and
+    transcript is the full transcript string.
+    """
+    b64_audio = base64.b64encode(audio_bytes).decode("utf-8")
+    client = create_audio_inference_client(
+        (grpc_endpoint, ""),
+        infer_protocol="grpc",
+        auth_token=auth_token,
+        function_id=function_id,
+        use_ssl=use_ssl if use_ssl is not None else False,
+        ssl_cert=ssl_cert,
+    )
+    segments, transcript = client.infer(
+        b64_audio,
+        model_name="parakeet",
+        stage_name="audio_extraction",
+    )
+    return segments, transcript
+
+
+def _build_rows(
+    transcript: str,
+    segments: list,
+    path: str,
+    *,
+    segment_audio: bool = False,
+    max_tokens: Optional[int] = None,
+    overlap_tokens: int = 0,
+    tokenizer_model_id: Optional[str] = None,
+    tokenizer_cache_dir: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+    """Convert transcript/segments into row dicts for the output DataFrame."""
+    rows: List[Dict[str, Any]] = []
+
+    if segment_audio:
+        for i, seg in enumerate(segments):
+            text = seg.get("text", "")
+            if not text.strip():
+                continue
+            rows.append(
+                {
+                    "text": text,
+                    "content": text,
+                    "path": path,
+                    # "page_number": None,
+                    "metadata": {
+                        "source_path": path,
+                        "chunk_index": i,
+                        "content_metadata": {
+                            "type": "audio",
+                            "start_time": seg.get("start"),
+                            "end_time": seg.get("end"),
+                        },
+                        "content": text,
+                    },
+                }
+            )
+    elif max_tokens is not None:
+        model_id = tokenizer_model_id or DEFAULT_TOKENIZER_MODEL_ID
+        tokenizer = _get_tokenizer(model_id, cache_dir=tokenizer_cache_dir)
+        chunk_texts = _split_text_by_tokens(
+            transcript,
+            tokenizer=tokenizer,
+            max_tokens=max_tokens,
+            overlap_tokens=overlap_tokens,
+        )
+        for i, chunk in enumerate(chunk_texts):
+            rows.append(
+                {
+                    "text": chunk,
+                    "content": chunk,
+                    "path": path,
+                    # "page_number": None,
+                    "metadata": {
+                        "source_path": path,
+                        "chunk_index": i,
+                        "content_metadata": {"type": "audio"},
+                        "content": chunk,
+                    },
+                }
+            )
+    else:
+        rows.append(
+            {
+                "text": transcript,
+                "content": transcript,
+                "path": path,
+                # "page_number": None,
+                "metadata": {
+                    "source_path": path,
+                    "chunk_index": 0,
+                    "content_metadata": {"type": "audio"},
+                    "content": transcript,
+                },
+            }
+        )
+
+    return rows
+
+
+_EMPTY_DF_COLUMNS = ["text", "content", "path", "page_number", "metadata"]
+
+
+def audio_file_to_transcript_df(
+    path: str,
+    grpc_endpoint: str = "audio:50051",
+    auth_token: Optional[str] = None,
+    function_id: Optional[str] = None,
+    use_ssl: Optional[bool] = None,
+    ssl_cert: Optional[str] = None,
+    segment_audio: bool = False,
+    max_tokens: Optional[int] = None,
+    overlap_tokens: int = 0,
+    tokenizer_model_id: Optional[str] = None,
+    tokenizer_cache_dir: Optional[str] = None,
+    **kwargs: Any,
+) -> pd.DataFrame:
+    """Read an audio file, transcribe via Parakeet/Riva, return a chunk DataFrame.
+
+    Parameters
+    ----------
+    path : str
+        Path to an audio file (mp3, wav, etc.).
+    grpc_endpoint : str
+        Riva/Parakeet gRPC endpoint.
+    segment_audio : bool
+        If True, each speech segment becomes its own row.
+    max_tokens : int, optional
+        If set (and segment_audio is False), chunk the transcript by token
+        count using the same logic as txt extraction.
+    overlap_tokens : int
+        Token overlap between consecutive chunks (only used with max_tokens).
+
+    Returns
+    -------
+    pd.DataFrame
+        Columns: text, content, path, metadata (page_number only on the empty frame).
+    """
+    abs_path = str(Path(path).resolve())
+    audio_bytes = Path(abs_path).read_bytes()
+
+    segments, transcript = _transcribe_audio_bytes(
+        audio_bytes,
+        grpc_endpoint=grpc_endpoint,
+        auth_token=auth_token,
+        function_id=function_id,
+        use_ssl=use_ssl,
+        ssl_cert=ssl_cert,
+    )
+
+    if not transcript or not transcript.strip():
+        return pd.DataFrame(columns=_EMPTY_DF_COLUMNS).astype({"page_number": "int64"})
+
+    rows = _build_rows(
+        transcript,
+        segments,
+        abs_path,
+        segment_audio=segment_audio,
+        max_tokens=max_tokens,
+        overlap_tokens=overlap_tokens,
+        tokenizer_model_id=tokenizer_model_id,
+        tokenizer_cache_dir=tokenizer_cache_dir,
+    )
+
+    if not rows:
+        return pd.DataFrame(columns=_EMPTY_DF_COLUMNS).astype({"page_number": "int64"})
+
+    return pd.DataFrame(rows)
+
+
+def audio_bytes_to_transcript_df(
+    content_bytes: bytes,
+    path: str,
+    grpc_endpoint: str = "audio:50051",
+    auth_token: Optional[str] = None,
+    function_id: Optional[str] = None,
+    use_ssl: Optional[bool] = None,
+    ssl_cert: Optional[str] = None,
+    segment_audio: bool = False,
+    max_tokens: Optional[int] = None,
+    overlap_tokens: int = 0,
+    tokenizer_model_id: Optional[str] = None,
+    tokenizer_cache_dir: Optional[str] = None,
+    **kwargs: Any,
+) -> pd.DataFrame:
+    """Transcribe audio from raw bytes (for Ray Data batch mode).
+
+    Same as :func:`audio_file_to_transcript_df` but accepts bytes directly
+    instead of reading from disk.
+    """
+    abs_path = str(Path(path).resolve())
+
+    segments, transcript = _transcribe_audio_bytes(
+        content_bytes,
+        grpc_endpoint=grpc_endpoint,
+        auth_token=auth_token,
+        function_id=function_id,
+        use_ssl=use_ssl,
+        ssl_cert=ssl_cert,
+    )
+
+    if not transcript or not transcript.strip():
+        return pd.DataFrame(columns=_EMPTY_DF_COLUMNS).astype({"page_number": "int64"})
+
+    rows = _build_rows(
+        transcript,
+        segments,
+        abs_path,
+        segment_audio=segment_audio,
+        max_tokens=max_tokens,
+        overlap_tokens=overlap_tokens,
+        tokenizer_model_id=tokenizer_model_id,
+        tokenizer_cache_dir=tokenizer_cache_dir,
+    )
+
+    if not rows:
+        return pd.DataFrame(columns=_EMPTY_DF_COLUMNS).astype({"page_number": "int64"})
+
+    return pd.DataFrame(rows)
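Taken together, transcribe.py exposes two entry points that differ only in how the audio reaches them: audio_file_to_transcript_df reads from disk, audio_bytes_to_transcript_df takes raw bytes. A usage sketch against the compose endpoint (the file name is hypothetical; assumes the Parakeet NIM is running):

from retriever.audio import audio_file_to_transcript_df

df = audio_file_to_transcript_df(
    "meeting.mp3",                    # hypothetical input file
    grpc_endpoint="localhost:50051",  # host-published compose endpoint
    segment_audio=True,               # one row per speech segment
)
for row in df.itertuples():
    cm = row.metadata["content_metadata"]
    print(cm["start_time"], cm["end_time"], row.text)
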
diff --git a/retriever/src/retriever/examples/batch_pipeline.py b/retriever/src/retriever/examples/batch_pipeline.py
index 5a9ad955e..d2dac77b7 100644
--- a/retriever/src/retriever/examples/batch_pipeline.py
+++ b/retriever/src/retriever/examples/batch_pipeline.py
@@ -138,7 +138,11 @@ def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]:
     try:
         db = lancedb.connect(uri)
         table = db.open_table(table_name)
-        df = table.to_pandas()[["source_id", "page_number", "metadata"]]
+        pdf = table.to_pandas()
+        cols = ["source_id", "page_number", "metadata"]
+        if "text" in pdf.columns:
+            cols.append("text")
+        df = pdf[cols]
     except Exception:
         return None
@@ -167,6 +171,8 @@ def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]:
             "ocr_chart_total": 0,
             "ocr_infographic_total": 0,
             "page_elements_by_label": defaultdict(int),
+            "is_audio": False,
+            "audio_words": 0,
         },
     )
 
@@ -190,11 +196,20 @@ def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]:
                 _to_int(count, default=0),
             )
 
+        content_meta = meta.get("content_metadata")
+        if isinstance(content_meta, dict) and content_meta.get("type") == "audio":
+            entry["is_audio"] = True
+            text = str(getattr(row, "text", "") or "")
+            entry["audio_words"] = max(entry["audio_words"], len(text.split()))
+
     pe_by_label_totals: dict[str, int] = defaultdict(int)
     page_elements_total = 0
     ocr_table_total = 0
     ocr_chart_total = 0
     ocr_infographic_total = 0
+    audio_files = 0
+    audio_segments = 0
+    audio_words_total = 0
     for page_entry in per_page.values():
         page_elements_total += int(page_entry["page_elements_total"])
         ocr_table_total += int(page_entry["ocr_table_total"])
@@ -202,6 +217,12 @@ def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]:
         ocr_infographic_total += int(page_entry["ocr_infographic_total"])
         for label, count in page_entry["page_elements_by_label"].items():
             pe_by_label_totals[label] += int(count)
+        if page_entry["is_audio"]:
+            audio_segments += 1
+            audio_words_total += page_entry["audio_words"]
+
+    audio_paths = {k[0] for k, e in per_page.items() if e["is_audio"]}
+    audio_files = len(audio_paths)
 
     return {
         "pages_seen": int(len(per_page)),
@@ -210,6 +231,9 @@ def _collect_detection_summary(uri: str, table_name: str) -> Optional[dict]:
         "ocr_table_total_detections": int(ocr_table_total),
         "ocr_chart_total_detections": int(ocr_chart_total),
         "ocr_infographic_total_detections": int(ocr_infographic_total),
+        "audio_files": audio_files,
+        "audio_segments": audio_segments,
+        "audio_words_total": audio_words_total,
     }
 
@@ -231,6 +255,12 @@ def _print_detection_summary(summary: Optional[dict]) -> None:
         for label, count in by_label.items():
             print(f" {label}: {count}")
 
+    audio_files = summary.get("audio_files", 0)
+    if audio_files > 0:
+        print(f" Audio files transcribed: {audio_files}")
+        print(f" Audio transcript segments: {summary.get('audio_segments', 0)}")
+        print(f" Audio transcript words: {summary.get('audio_words_total', 0)}")
+
 
 def _write_detection_summary(path: Path, summary: Optional[dict]) -> None:
     target = Path(path).expanduser().resolve()
@@ -339,14 +369,14 @@ def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]:
 def main(
     input_dir: Path = typer.Argument(
         ...,
-        help="Directory containing PDFs, .txt, .html, or .doc/.pptx files to ingest.",
+        help="Directory containing PDFs, .txt, .html, .doc/.pptx, or .mp3/.wav files to ingest.",
         path_type=Path,
         exists=True,
     ),
     input_type: str = typer.Option(
         "pdf",
         "--input-type",
-        help="Input format: 'pdf', 'txt', 'html', or 'doc'. Use 'txt' for .txt, 'html' for .html (markitdown -> chunks), 'doc' for .docx/.pptx (converted to PDF via LibreOffice).",  # noqa: E501
+        help="Input format: 'pdf', 'txt', 'html', 'doc', or 'audio'. Use 'txt' for .txt, 'html' for .html (markitdown -> chunks), 'doc' for .docx/.pptx (converted to PDF via LibreOffice), 'audio' for .mp3/.wav (transcribed via Riva/Parakeet).",  # noqa: E501
     ),
     ray_address: Optional[str] = typer.Option(
         None,
@@ -531,6 +561,16 @@ def main(
         dir_okay=False,
         help="Optional JSON file path to write end-of-run detection counts summary.",
     ),
+    audio_grpc_endpoint: str = typer.Option(
+        "audio:50051",
+        "--audio-grpc-endpoint",
+        help="Riva/Parakeet gRPC endpoint for audio transcription (used with --input-type audio).",
+    ),
+    segment_audio: bool = typer.Option(
+        False,
+        "--segment-audio",
+        help="If set, each speech segment becomes its own row (used with --input-type audio).",
+    ),
 ) -> None:
     log_handle, original_stdout, original_stderr = _configure_logging(log_file)
     try:
@@ -603,6 +643,35 @@ def main(
                 )
             )
         )
+    elif input_type == "audio":
+        import glob as _glob
+
+        audio_exts = ("*.mp3", "*.wav")
+        audio_files = [f for ext in audio_exts for f in _glob.glob(str(input_dir / ext))]
+        if not audio_files:
+            raise typer.BadParameter(f"No audio files (.mp3/.wav) found in {input_dir}")
+        ingestor = create_ingestor(
+            run_mode="batch",
+            params=IngestorCreateParams(ray_address=ray_address, ray_log_to_driver=ray_log_to_driver),
+        )
+        ingestor = (
+            ingestor.files(audio_files)
+            .extract_audio(
+                grpc_endpoint=audio_grpc_endpoint,
+                segment_audio=segment_audio,
+            )
+            .embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
+            .vdb_upload(
+                VdbUploadParams(
+                    lancedb={
+                        "lancedb_uri": lancedb_uri,
+                        "table_name": LANCEDB_TABLE,
+                        "overwrite": True,
+                        "create_index": True,
+                    }
+                )
+            )
+        )
     elif input_type == "doc":
         # DOCX/PPTX: same pipeline as PDF; DocToPdfConversionActor converts before split.
         doc_globs = [str(input_dir / "*.docx"), str(input_dir / "*.pptx")]
@@ -791,11 +860,12 @@ def main(
     if not no_recall_details:
         print("\nPer-query retrieval details:")
         missed_gold: list[tuple[str, str]] = []
-        ext = (
-            ".html"
-            if input_type == "html"
-            else (".txt" if input_type == "txt" else (".docx" if input_type == "doc" else ".pdf"))
-        )
+        ext = {
+            "txt": ".txt",
+            "html": ".html",
+            "doc": ".docx",
+            "audio": ".mp3",
+        }.get(input_type, ".pdf")
         for i, (q, g, hits) in enumerate(
             zip(
                 _df_query["query"].astype(str).tolist(),
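The batch example is driven from the command line. An invocation for an audio corpus might look like the following sketch (the directory name is hypothetical, and running the example as a module assumes the examples package is importable; the flag names come from the diff above):

import subprocess

subprocess.run(
    [
        "python", "-m", "retriever.examples.batch_pipeline",
        "audio_corpus/",  # hypothetical directory of .mp3/.wav files
        "--input-type", "audio",
        "--audio-grpc-endpoint", "localhost:50051",
        "--segment-audio",
    ],
    check=True,
)
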
diff --git a/retriever/src/retriever/examples/inprocess_pipeline.py b/retriever/src/retriever/examples/inprocess_pipeline.py
index 9eb61875f..67e252244 100644
--- a/retriever/src/retriever/examples/inprocess_pipeline.py
+++ b/retriever/src/retriever/examples/inprocess_pipeline.py
@@ -100,14 +100,14 @@ def _hit_key_and_distance(hit: dict) -> tuple[str | None, float | None]:
 def main(
     input_dir: Path = typer.Argument(
         ...,
-        help="Directory containing PDFs, .txt, .html, or .doc/.pptx files to ingest.",
+        help="Directory containing PDFs, .txt, .html, .doc/.pptx, or .mp3/.wav files to ingest.",
         path_type=Path,
         exists=True,
     ),
     input_type: str = typer.Option(
         "pdf",
         "--input-type",
-        help="Input format: 'pdf', 'txt', 'html', or 'doc'. Use 'txt' for .txt, 'html' for .html (markitdown -> chunks), 'doc' for .docx/.pptx (converted to PDF via LibreOffice).",  # noqa: E501
+        help="Input format: 'pdf', 'txt', 'html', 'doc', or 'audio'. Use 'txt' for .txt, 'html' for .html (markitdown -> chunks), 'doc' for .docx/.pptx (converted to PDF via LibreOffice), 'audio' for .mp3/.wav (transcribed via Riva/Parakeet).",  # noqa: E501
     ),
     query_csv: Path = typer.Option(
         "bo767_query_gt.csv",
@@ -150,6 +150,16 @@ def main(
         "--embed-invoke-url",
         help="Optional remote endpoint URL for embedding model inference.",
     ),
+    audio_grpc_endpoint: str = typer.Option(
+        "audio:50051",
+        "--audio-grpc-endpoint",
+        help="Riva/Parakeet gRPC endpoint for audio transcription (used with --input-type audio).",
+    ),
+    segment_audio: bool = typer.Option(
+        False,
+        "--segment-audio",
+        help="If set, each speech segment becomes its own row (used with --input-type audio).",
+    ),
     embed_model_name: str = typer.Option(
         "nvidia/llama-3.2-nv-embedqa-1b-v2",
         "--embed-model-name",
@@ -204,6 +214,28 @@ def main(
                 )
             )
         )
+    elif input_type == "audio":
+        import glob as _glob
+
+        audio_exts = ("*.mp3", "*.wav")
+        audio_files = [f for ext in audio_exts for f in _glob.glob(str(input_dir / ext))]
+        if not audio_files:
+            raise typer.BadParameter(f"No audio files found in {input_dir}")
+        ingestor = create_ingestor(run_mode="inprocess")
+        ingestor = (
+            ingestor.files(audio_files)
+            .extract_audio(
+                grpc_endpoint=audio_grpc_endpoint,
+                segment_audio=segment_audio,
+            )
+            .embed(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url)
+            .vdb_upload(
+                lancedb_uri=LANCEDB_URI,
+                table_name=LANCEDB_TABLE,
+                overwrite=True,
+                create_index=True,
+            )
+        )
     elif input_type == "doc":
         # DOCX/PPTX: same pipeline as PDF; inprocess loader converts to PDF then splits.
         doc_globs = [str(input_dir / "*.docx"), str(input_dir / "*.pptx")]
@@ -329,11 +361,12 @@ def main(
             hit = _is_hit_at_k(g, top_keys, cfg.top_k)
 
             if not no_recall_details:
-                ext = (
-                    ".txt"
-                    if input_type == "txt"
-                    else (".html" if input_type == "html" else (".docx" if input_type == "doc" else ".pdf"))
-                )
+                ext = {
+                    "txt": ".txt",
+                    "html": ".html",
+                    "doc": ".docx",
+                    "audio": ".mp3",
+                }.get(input_type, ".pdf")
                 print(f"\nQuery {i}: {q}")
                 print(f" Gold: {g} (file: {doc}{ext}, page: {page})")
                 print(f" Hit@{cfg.top_k}: {hit}")
@@ -348,11 +381,12 @@ def main(
                     print(f" {rank:02d}. {key} distance={dist:.6f}")
 
             if not hit:
-                ext = (
-                    ".txt"
-                    if input_type == "txt"
-                    else (".html" if input_type == "html" else (".docx" if input_type == "doc" else ".pdf"))
-                )
+                ext = {
+                    "txt": ".txt",
+                    "html": ".html",
+                    "doc": ".docx",
+                    "audio": ".mp3",
+                }.get(input_type, ".pdf")
                 missed_gold.append((f"{doc}{ext}", str(page)))
 
     missed_unique = sorted(set(missed_gold), key=lambda x: (x[0], x[1]))
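Outside the example scripts, the same audio flow is available through the fluent ingestor API the diff wires up (a sketch; the create_ingestor import path and the file names are assumptions):

from retriever import create_ingestor  # import path assumed

results = (
    create_ingestor(run_mode="inprocess")
    .files(["clip1.mp3", "clip2.wav"])  # hypothetical inputs
    .extract_audio(grpc_endpoint="localhost:50051")
    .embed(model_name="nemo_retriever_v1")
    .vdb_upload(lancedb_uri="lancedb", table_name="nv-ingest", overwrite=True, create_index=True)
    .ingest()
)
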
diff --git a/retriever/src/retriever/ingest-config.yaml b/retriever/src/retriever/ingest-config.yaml
index 54b4f7b40..695e55363 100644
--- a/retriever/src/retriever/ingest-config.yaml
+++ b/retriever/src/retriever/ingest-config.yaml
@@ -73,6 +73,22 @@ html:
   tokenizer_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
   encoding: utf-8
 
+# Optional config for `retriever audio stage run` and the .extract_audio() API
+audio:
+  max_queue_size: 1
+  n_workers: 2
+  raise_on_failure: false
+
+  audio_extraction_config:
+    auth_token: null
+    # [grpc, http] — only gRPC is supported for audio
+    audio_endpoints: ["audio:50051", null]
+    # audio_infer_protocol: grpc
+    function_id: null
+    use_ssl: null
+    ssl_cert: null
+    segment_audio: false
+
 table:
   # Example config for:
   #   - `retriever table stage run --config <config.yaml> --input <input>`
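The audio block mirrors audio_stage_config.yaml, so the two-element audio_endpoints list can be sanity-checked the same way in both places (a sketch; assumes PyYAML and the repo-relative path):

import yaml

with open("retriever/src/retriever/ingest-config.yaml") as f:
    cfg = yaml.safe_load(f)

grpc_ep, http_ep = cfg["audio"]["audio_extraction_config"]["audio_endpoints"]
assert grpc_ep and http_ep is None  # only gRPC is supported for audio
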
diff --git a/retriever/src/retriever/ingest_modes/batch.py b/retriever/src/retriever/ingest_modes/batch.py
index 769180cfb..0ef6375a7 100644
--- a/retriever/src/retriever/ingest_modes/batch.py
+++ b/retriever/src/retriever/ingest_modes/batch.py
@@ -31,6 +31,7 @@ from retriever.pdf.split import PDFSplitActor
 
 from ..ingest import Ingestor
+from ..params import AudioExtractParams
 from ..params import EmbedParams
 from ..params import ExtractParams
 from ..params import HtmlChunkParams
@@ -194,6 +195,10 @@ def _build_rows(self, df: Any) -> list:
             entries = getattr(row, ocr_col, None)
             if isinstance(entries, list):
                 metadata_obj[f"ocr_{ocr_col}_detections"] = int(len(entries))
+            if isinstance(meta, dict):
+                content_meta = meta.get("content_metadata")
+                if isinstance(content_meta, dict):
+                    metadata_obj["content_metadata"] = content_meta
 
             source_obj = {"source_id": str(path)}
             row_out = {
@@ -695,6 +700,30 @@ def extract_html(self, params: HtmlChunkParams | None = None, **kwargs: Any) ->
         )
         return self
 
+    def extract_audio(self, params: AudioExtractParams | None = None, **kwargs: Any) -> "BatchIngestor":
+        """
+        Configure audio pipeline: read_binary_files -> AudioTranscribeActor (bytes -> transcript rows).
+
+        Use with .files("*.mp3|*.wav").extract_audio(...).embed().vdb_upload().ingest().
+        Do not call .extract() when using .extract_audio().
+        """
+        from retriever.audio.ray_data import AudioTranscribeActor
+
+        self._pipeline_type = "audio"
+        resolved = _coerce_params(params, AudioExtractParams, kwargs)
+        self._extract_audio_kwargs = resolved.model_dump(mode="python")
+        self._tasks.append(("extract_audio", dict(self._extract_audio_kwargs)))
+
+        self._rd_dataset = self._rd_dataset.map_batches(
+            AudioTranscribeActor,
+            batch_size=1,
+            batch_format="pandas",
+            num_cpus=1,
+            num_gpus=0,
+            fn_constructor_kwargs={"params": AudioExtractParams(**self._extract_audio_kwargs)},
+        )
+        return self
+
     def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "BatchIngestor":
         """
         Add a text-embedding stage to the batch pipeline.
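In batch mode, extract_audio() is just a map_batches over the binary-file dataset. The equivalent standalone Ray Data wiring, mirroring the arguments used in batch.py above (a sketch; assumes Ray is installed and ./audio/ holds the input files):

import ray
from retriever.audio.ray_data import AudioTranscribeActor
from retriever.params import AudioExtractParams

# read_binary_files with include_paths=True yields the "bytes" and "path"
# columns that AudioTranscribeActor expects.
ds = ray.data.read_binary_files("audio/", include_paths=True)
ds = ds.map_batches(
    AudioTranscribeActor,
    batch_size=1,
    batch_format="pandas",
    fn_constructor_kwargs={"params": AudioExtractParams(grpc_endpoint="localhost:50051")},
)
transcripts = ds.to_pandas()
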
diff --git a/retriever/src/retriever/ingest_modes/inprocess.py b/retriever/src/retriever/ingest_modes/inprocess.py
index f3d620e4c..b94036af5 100644
--- a/retriever/src/retriever/ingest_modes/inprocess.py
+++ b/retriever/src/retriever/ingest_modes/inprocess.py
@@ -1,7 +1,3 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
-# All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
 """
 In-process runmode.
 
@@ -52,10 +48,12 @@
 from ..params import IngestExecuteParams
 from ..params import TextChunkParams
 from ..params import VdbUploadParams
+from ..params import AudioExtractParams
 from ..pdf.extract import pdf_extraction
 from ..pdf.split import _split_pdf_to_single_page_bytes, pdf_path_to_pages_df
 from ..txt import txt_file_to_chunks_df
 from ..html import html_file_to_chunks_df
+from ..audio import audio_file_to_transcript_df
 
 _CONTENT_COLUMNS = ("table", "chart", "infographic")
@@ -882,6 +880,8 @@ def _collect_summary_from_df(df: pd.DataFrame) -> dict:
             "ocr_chart": 0,
             "ocr_infographic": 0,
             "pe_by_label": defaultdict(int),
+            "is_audio": False,
+            "audio_words": 0,
         },
     )
@@ -921,8 +921,20 @@ def _collect_summary_from_df(df: pd.DataFrame) -> dict:
                 c = 0
             entry["pe_by_label"][str(label)] = max(entry["pe_by_label"][str(label)], c)
 
+        content_type = None
+        content_meta = meta.get("content_metadata")
+        if isinstance(content_meta, dict):
+            content_type = content_meta.get("type")
+        if content_type == "audio":
+            entry["is_audio"] = True
+            text = str(row_dict.get("text") or "")
+            entry["audio_words"] = max(entry["audio_words"], len(text.split()))
+
     pe_by_label_totals: dict[str, int] = defaultdict(int)
     pe_total = ocr_table_total = ocr_chart_total = ocr_infographic_total = 0
+    audio_files = 0
+    audio_segments = 0
+    audio_words_total = 0
     for e in per_page.values():
         pe_total += e["pe"]
         ocr_table_total += e["ocr_table"]
@@ -930,6 +942,12 @@ def _collect_summary_from_df(df: pd.DataFrame) -> dict:
         ocr_infographic_total += e["ocr_infographic"]
         for label, count in e["pe_by_label"].items():
             pe_by_label_totals[label] += count
+        if e["is_audio"]:
+            audio_segments += 1
+            audio_words_total += e["audio_words"]
+
+    audio_paths = {k[0] for k, e in per_page.items() if e["is_audio"]}
+    audio_files = len(audio_paths)
 
     return {
         "pages_seen": len(per_page),
@@ -938,6 +956,9 @@ def _collect_summary_from_df(df: pd.DataFrame) -> dict:
         "ocr_table_total_detections": ocr_table_total,
         "ocr_chart_total_detections": ocr_chart_total,
         "ocr_infographic_total_detections": ocr_infographic_total,
+        "audio_files": audio_files,
+        "audio_segments": audio_segments,
+        "audio_words_total": audio_words_total,
     }
@@ -965,6 +986,12 @@ def _print_ingest_summary(results: list, elapsed_s: float) -> None:
         for label, count in by_label.items():
             print(f" {label}: {count}")
 
+    audio_files = summary.get("audio_files", 0)
+    if audio_files > 0:
+        print(f" Audio files transcribed: {audio_files}")
+        print(f" Audio transcript segments: {summary.get('audio_segments', 0)}")
+        print(f" Audio transcript words: {summary.get('audio_words_total', 0)}")
+
     pages = summary["pages_seen"]
     if elapsed_s > 0 and pages > 0:
         pps = pages / elapsed_s
@@ -977,8 +1004,8 @@ class InProcessIngestor(Ingestor):
     RUN_MODE = "inprocess"
 
-    def __init__(self, documents: Optional[List[str]] = None) -> None:
-        super().__init__(documents=documents)
+    def __init__(self, documents: Optional[List[str]] = None, **kwargs: Any) -> None:
+        super().__init__(documents=documents, **kwargs)
 
         # Keep backwards-compatibility with code that inspects `Ingestor._documents`
         # by ensuring both names refer to the same list.
@@ -987,8 +1014,9 @@ def __init__(self, documents: Optional[List[str]] = None, **kwargs: Any) -> Non
         # Builder-style configuration recorded for later execution (TBD).
         self._tasks: List[tuple[Callable[..., Any], dict[str, Any]]] = []
 
-        # Pipeline type: "pdf" (extract), "txt" (extract_txt), or "html" (extract_html). Loader dispatch in ingest().
-        self._pipeline_type: Literal["pdf", "txt", "html"] = "pdf"
+        # Pipeline type: "pdf" (extract), "txt" (extract_txt), "html" (extract_html),
+        # or "audio" (extract_audio). Loader dispatch in ingest().
+        self._pipeline_type: Literal["pdf", "txt", "html", "audio"] = "pdf"
 
         self._extract_txt_kwargs: Dict[str, Any] = {}
         self._extract_html_kwargs: Dict[str, Any] = {}
@@ -1124,14 +1152,13 @@ def _detect_kwargs_with_model(model_obj: Any, *, stage_name: str, allow_remote:
         if ocr_invoke_url:
             self._tasks.append((ocr_page_elements, {"model": None, **ocr_flags}))
         else:
-            ocr_model_dir = (
-                kwargs.get("ocr_model_dir")
-                or os.environ.get("RETRIEVER_NEMOTRON_OCR_MODEL_DIR", "").strip()
-                or os.environ.get("NEMOTRON_OCR_MODEL_DIR", "").strip()
-                or os.environ.get("NEMOTRON_OCR_V1_MODEL_DIR", "").strip()
-            )
-            model = NemotronOCRV1(model_dir=str(ocr_model_dir)) if ocr_model_dir else NemotronOCRV1()
-            self._tasks.append((ocr_page_elements, {"model": model, **ocr_flags}))
+            ocr_model_dir = os.environ.get("NEMOTRON_OCR_MODEL_DIR", "")
+            if not ocr_model_dir:
+                raise RuntimeError(
+                    "NEMOTRON_OCR_MODEL_DIR environment variable must be set to "
+                    "the path of the Nemotron OCR v1 model directory."
+                )
+            self._tasks.append((ocr_page_elements, {"model": NemotronOCRV1(model_dir=ocr_model_dir), **ocr_flags}))
 
         return self
@@ -1159,6 +1186,18 @@ def extract_html(self, params: HtmlChunkParams | None = None, **kwargs: Any) ->
         self._extract_html_kwargs = resolved.model_dump(mode="python")
         return self
 
+    def extract_audio(self, params: AudioExtractParams | None = None, **kwargs: Any) -> "InProcessIngestor":
+        """
+        Configure audio ingestion: transcribe via Riva/Parakeet NIM.
+
+        Use with .files("*.mp3").extract_audio(...).embed().vdb_upload().ingest().
+        Do not call .extract() when using .extract_audio().
+        """
+        self._pipeline_type = "audio"
+        resolved = _coerce_params(params, AudioExtractParams, kwargs)
+        self._extract_audio_kwargs = resolved.model_dump(mode="python")
+        return self
+
     def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessIngestor":
         """
         Configure embedding for in-process execution.
@@ -1316,8 +1355,14 @@ def ingest(self, params: IngestExecuteParams | None = None, **kwargs: Any) -> li
 
         docs = list(self._documents)
 
+        # Page-level chunking (_iter_page_chunks) relies on pdfium and only
+        # works for PDF (and DOCX/PPTX converted to PDF). Other pipeline
+        # types (txt, html, audio) use _loader dispatch in the fully-
+        # sequential path below.
+        _supports_page_chunks = self._pipeline_type == "pdf"
+
         # -- Parallel execution branch ------------------------------------
-        if parallel:
+        if parallel and _supports_page_chunks:
             if gpu_devices and len(gpu_devices) >= 1 and gpu_tasks:
                 # Pipelined: GPU workers load models while CPU runs,
                 # each completed chunk goes to GPU immediately.
@@ -1464,7 +1509,7 @@ def _on_gpu_done(sid: int) -> None:
             return results
 
         # -- Sequential execution branch (default) ------------------------
-        use_multi_gpu_seq = gpu_devices and len(gpu_devices) >= 1 and gpu_tasks
+        use_multi_gpu_seq = _supports_page_chunks and gpu_devices and len(gpu_devices) >= 1 and gpu_tasks
 
         if use_multi_gpu_seq:
             # Pipelined: GPU workers process earlier chunks while CPU
@@ -1563,12 +1608,17 @@ def _loader(p: str) -> pd.DataFrame:
         elif self._pipeline_type == "html":
 
             def _loader(p: str) -> pd.DataFrame:
-                return html_file_to_chunks_df(p, params=HtmlChunkParams(**self._extract_html_kwargs))
+                return html_file_to_chunks_df(p, **self._extract_html_kwargs)
+
+        elif self._pipeline_type == "audio":
+
+            def _loader(p: str) -> pd.DataFrame:
+                return audio_file_to_transcript_df(p, **self._extract_audio_kwargs)
 
         else:
 
             def _loader(p: str) -> pd.DataFrame:
-                return txt_file_to_chunks_df(p, params=TextChunkParams(**self._extract_txt_kwargs))
+                return txt_file_to_chunks_df(p, **self._extract_txt_kwargs)
 
         for doc_path in doc_iter:
             initial_df = _loader(doc_path)
diff --git a/retriever/src/retriever/params/__init__.py b/retriever/src/retriever/params/__init__.py
index bdaa9b2a3..390ce1381 100644
--- a/retriever/src/retriever/params/__init__.py
+++ b/retriever/src/retriever/params/__init__.py
@@ -2,6 +2,7 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from .models import AudioExtractParams
 from .models import BatchTuningParams
 from .models import ChartParams
 from .models import EmbedParams
@@ -25,6 +26,7 @@ from .models import VdbUploadParams
 
 __all__ = [
+    "AudioExtractParams",
     "BatchTuningParams",
     "ChartParams",
     "EmbedParams",
diff --git a/retriever/src/retriever/params/models.py b/retriever/src/retriever/params/models.py
index a64ef03a8..328b9a147 100644
--- a/retriever/src/retriever/params/models.py
+++ b/retriever/src/retriever/params/models.py
@@ -72,6 +72,19 @@ class HtmlChunkParams(TextChunkParams):
     pass
 
 
+class AudioExtractParams(_ParamsModel):
+    grpc_endpoint: str = "audio:50051"
+    auth_token: Optional[str] = None
+    function_id: Optional[str] = None
+    use_ssl: Optional[bool] = None
+    ssl_cert: Optional[str] = None
+    segment_audio: bool = False
+    max_tokens: Optional[int] = None
+    overlap_tokens: int = 0
+    tokenizer_model_id: Optional[str] = None
+    tokenizer_cache_dir: Optional[str] = None
+
+
 class LanceDbParams(_ParamsModel):
     lancedb_uri: str = "lancedb"
     table_name: str = "nv-ingest"
rows[0]["text"] == " " + + +# --------------------------------------------------------------------------- +# _build_rows — segmented +# --------------------------------------------------------------------------- + + +class TestBuildRowsSegmented: + def test_segment_rows(self): + rows = _build_rows(SAMPLE_TRANSCRIPT, SAMPLE_SEGMENTS, "/audio/test.mp3", segment_audio=True) + assert len(rows) == 2 + assert rows[0]["text"] == "the quick brown fox" + assert rows[0]["metadata"]["content_metadata"]["start_time"] == 0.0 + assert rows[0]["metadata"]["content_metadata"]["end_time"] == 1.5 + assert rows[1]["text"] == "jumped over the lazy dog" + assert rows[1]["metadata"]["chunk_index"] == 1 + + def test_empty_segments_skipped(self): + segs = [ + {"start": 0.0, "end": 1.0, "text": "hello"}, + {"start": 1.0, "end": 2.0, "text": " "}, + {"start": 2.0, "end": 3.0, "text": "world"}, + ] + rows = _build_rows("hello world", segs, "/audio/test.mp3", segment_audio=True) + assert len(rows) == 2 + assert rows[0]["text"] == "hello" + assert rows[1]["text"] == "world" + + +# --------------------------------------------------------------------------- +# _build_rows — token-chunked +# --------------------------------------------------------------------------- + + +class TestBuildRowsTokenChunked: + @patch("retriever.audio.transcribe._get_tokenizer", return_value=_MockTokenizer()) + def test_chunked_rows(self, _mock_tok): + rows = _build_rows(SAMPLE_TRANSCRIPT, SAMPLE_SEGMENTS, "/audio/test.mp3", max_tokens=3) + assert len(rows) >= 2 + for i, row in enumerate(rows): + assert row["metadata"]["chunk_index"] == i + assert row["metadata"]["content_metadata"]["type"] == "audio" + assert row["text"] == row["content"] + assert row["text"] == row["metadata"]["content"] + + +# --------------------------------------------------------------------------- +# audio_file_to_transcript_df +# --------------------------------------------------------------------------- + + +def _mock_transcribe(_audio_bytes, *, grpc_endpoint, **kw): + return SAMPLE_SEGMENTS, SAMPLE_TRANSCRIPT + + +class TestAudioFileToTranscriptDf: + @patch("retriever.audio.transcribe._transcribe_audio_bytes", side_effect=_mock_transcribe) + def test_basic(self, _mock, tmp_path: Path): + f = tmp_path / "clip.wav" + f.write_bytes(b"\x00\x01\x02\x03") + df = audio_file_to_transcript_df(str(f), grpc_endpoint="localhost:50051") + assert isinstance(df, pd.DataFrame) + assert set(df.columns) >= {"text", "content", "path", "metadata"} + assert len(df) == 1 + assert df["text"].iloc[0] == SAMPLE_TRANSCRIPT + assert df["path"].iloc[0] == str(f.resolve()) + meta = df["metadata"].iloc[0] + assert meta["content_metadata"]["type"] == "audio" + + @patch("retriever.audio.transcribe._transcribe_audio_bytes", return_value=([], "")) + def test_empty_transcript(self, _mock, tmp_path: Path): + f = tmp_path / "silent.wav" + f.write_bytes(b"\x00") + df = audio_file_to_transcript_df(str(f), grpc_endpoint="localhost:50051") + assert isinstance(df, pd.DataFrame) + assert len(df) == 0 + assert list(df.columns) == ["text", "content", "path", "page_number", "metadata"] + + @patch("retriever.audio.transcribe._transcribe_audio_bytes", side_effect=_mock_transcribe) + def test_segmented(self, _mock, tmp_path: Path): + f = tmp_path / "clip.mp3" + f.write_bytes(b"\xff\xfb") + df = audio_file_to_transcript_df(str(f), grpc_endpoint="localhost:50051", segment_audio=True) + assert len(df) == 2 + assert df["text"].iloc[0] == "the quick brown fox" + assert 
df["metadata"].iloc[0]["content_metadata"]["start_time"] == 0.0 + + @patch("retriever.audio.transcribe._transcribe_audio_bytes", side_effect=_mock_transcribe) + @patch("retriever.audio.transcribe._get_tokenizer", return_value=_MockTokenizer()) + def test_token_chunked(self, _mock_tok, _mock_transcribe, tmp_path: Path): + f = tmp_path / "clip.wav" + f.write_bytes(b"\x00\x01") + df = audio_file_to_transcript_df(str(f), grpc_endpoint="localhost:50051", max_tokens=4) + assert len(df) >= 2 + for i in range(len(df)): + assert df["metadata"].iloc[i]["chunk_index"] == i + + @patch("retriever.audio.transcribe._transcribe_audio_bytes", side_effect=_mock_transcribe) + def test_extra_kwargs_ignored(self, _mock, tmp_path: Path): + f = tmp_path / "clip.wav" + f.write_bytes(b"\x00") + df = audio_file_to_transcript_df(str(f), grpc_endpoint="localhost:50051", unknown_param="ignored") + assert len(df) == 1 + + +# --------------------------------------------------------------------------- +# audio_bytes_to_transcript_df +# --------------------------------------------------------------------------- + + +class TestAudioBytesToTranscriptDf: + @patch("retriever.audio.transcribe._transcribe_audio_bytes", side_effect=_mock_transcribe) + def test_basic(self, _mock, tmp_path: Path): + path = str(tmp_path / "virtual.wav") + df = audio_bytes_to_transcript_df(b"\x00\x01\x02\x03", path, grpc_endpoint="localhost:50051") + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + assert df["text"].iloc[0] == SAMPLE_TRANSCRIPT + assert df["metadata"].iloc[0]["content_metadata"]["type"] == "audio" + + @patch("retriever.audio.transcribe._transcribe_audio_bytes", return_value=([], "")) + def test_empty_transcript(self, _mock, tmp_path: Path): + path = str(tmp_path / "silent.wav") + df = audio_bytes_to_transcript_df(b"\x00", path, grpc_endpoint="localhost:50051") + assert isinstance(df, pd.DataFrame) + assert len(df) == 0 + + @patch("retriever.audio.transcribe._transcribe_audio_bytes", side_effect=_mock_transcribe) + def test_segmented(self, _mock, tmp_path: Path): + path = str(tmp_path / "clip.mp3") + df = audio_bytes_to_transcript_df(b"\xff\xfb", path, grpc_endpoint="localhost:50051", segment_audio=True) + assert len(df) == 2 + assert df["text"].iloc[1] == "jumped over the lazy dog"