Skip to content

Commit 20da166

Browse files
committed
feat: add optional clear_vector_db() KFP component so users can easily clear the vector database at the start of a pipeline run
1 parent bbb2a10 commit 20da166

File tree

3 files changed

+1473
-314
lines changed

3 files changed

+1473
-314
lines changed

demos/kfp/docling/asr-conversion/README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ The pipeline enables rich RAG applications that can answer questions about spoke
106106
- `embed_model_id`: Embedding model to use (default: `ibm-granite/granite-embedding-125m-english`)
107107
- `max_tokens`: Maximum tokens per chunk (default: 512)
108108
- `use_gpu`: Whether to use GPU for processing (default: true)
109+
- `clean_vector_db`: Whether to clear the vector database at the start of the pipeline run (default: true)
109110

110111

111112
### Creating the Pipeline for running on GPU node
@@ -124,9 +125,6 @@ python3 docling_asr_convert_pipeline.py
124125

125126

126127
### Creating the Pipeline for running on CPU only
127-
128-
129-
130128
```
131129
# Install dependencies for pipeline
132130
cd demos/kfp/docling/asr-conversion

demos/kfp/docling/asr-conversion/docling_asr_convert_pipeline.py

Lines changed: 130 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ def docling_convert_pipeline(
489489
embed_model_id: str = "ibm-granite/granite-embedding-125m-english",
490490
max_tokens: int = 512,
491491
use_gpu: bool = True, # use only if you have additional gpu worker
492+
clean_vector_db: bool = True,  # if True, the vector database is cleared at the start of the pipeline run
492493
) -> None:
493494
"""
494495
Converts audio recordings to text using Docling ASR and generates embeddings
@@ -500,74 +501,138 @@ def docling_convert_pipeline(
500501
:param embed_model_id: Model ID for embedding generation
501502
:param max_tokens: Maximum number of tokens per chunk
502503
:param use_gpu: boolean to enable/disable gpu in the docling workers
504+
:param clean_vector_db: boolean to enable/disable clearing the vector database at the start of the pipeline run
503505
:return:
504506
"""
505-
clear_task = clear_vector_db(
506-
service_url=service_url,
507-
vector_db_id=vector_db_id,
508-
)
509-
clear_task.set_caching_options(False)
510-
511-
register_task = register_vector_db(
512-
service_url=service_url,
513-
vector_db_id=vector_db_id,
514-
embed_model_id=embed_model_id,
515-
).after(clear_task)
516-
register_task.set_caching_options(False)
507+
with dsl.If(clean_vector_db == True):
508+
clear_task = clear_vector_db(
509+
service_url=service_url,
510+
vector_db_id=vector_db_id,
511+
)
512+
clear_task.set_caching_options(False)
513+
514+
register_task = register_vector_db(
515+
service_url=service_url,
516+
vector_db_id=vector_db_id,
517+
embed_model_id=embed_model_id,
518+
).after(clear_task)
519+
register_task.set_caching_options(False)
520+
521+
import_task = import_audio_files(
522+
base_url=base_url,
523+
audio_filenames=audio_filenames,
524+
)
525+
import_task.set_caching_options(True)
526+
527+
audio_splits = create_audio_splits(
528+
input_path=import_task.output,
529+
num_splits=num_workers,
530+
).set_caching_options(True)
531+
532+
with dsl.ParallelFor(audio_splits.output) as audio_split:
533+
with dsl.If(use_gpu == True):
534+
convert_task = docling_convert_and_ingest_audio(
535+
input_path=import_task.output,
536+
audio_split=audio_split,
537+
embed_model_id=embed_model_id,
538+
max_tokens=max_tokens,
539+
service_url=service_url,
540+
vector_db_id=vector_db_id,
541+
)
542+
convert_task.set_caching_options(False)
543+
convert_task.set_cpu_request("500m")
544+
convert_task.set_cpu_limit("4")
545+
convert_task.set_memory_request("2Gi")
546+
convert_task.set_memory_limit("6Gi")
547+
convert_task.set_accelerator_type("nvidia.com/gpu")
548+
convert_task.set_accelerator_limit(1)
549+
add_toleration_json(
550+
convert_task,
551+
[
552+
{
553+
"effect": "NoSchedule",
554+
"key": "nvidia.com/gpu",
555+
"operator": "Exists",
556+
}
557+
],
558+
)
559+
add_node_selector_json(convert_task, {})
560+
with dsl.Else():
561+
convert_task = docling_convert_and_ingest_audio(
562+
input_path=import_task.output,
563+
audio_split=audio_split,
564+
embed_model_id=embed_model_id,
565+
max_tokens=max_tokens,
566+
service_url=service_url,
567+
vector_db_id=vector_db_id,
568+
)
569+
convert_task.set_caching_options(False)
570+
convert_task.set_cpu_request("500m")
571+
convert_task.set_cpu_limit("4")
572+
convert_task.set_memory_request("2Gi")
573+
convert_task.set_memory_limit("6Gi")
574+
575+
with dsl.Else():
576+
register_task = register_vector_db(
577+
service_url=service_url,
578+
vector_db_id=vector_db_id,
579+
embed_model_id=embed_model_id,
580+
)
581+
register_task.set_caching_options(False)
517582

518-
import_task = import_audio_files(
519-
base_url=base_url,
520-
audio_filenames=audio_filenames,
521-
)
522-
import_task.set_caching_options(True)
523-
524-
audio_splits = create_audio_splits(
525-
input_path=import_task.output,
526-
num_splits=num_workers,
527-
).set_caching_options(True)
528-
529-
with dsl.ParallelFor(audio_splits.output) as audio_split:
530-
with dsl.If(use_gpu == True):
531-
convert_task = docling_convert_and_ingest_audio(
532-
input_path=import_task.output,
533-
audio_split=audio_split,
534-
embed_model_id=embed_model_id,
535-
max_tokens=max_tokens,
536-
service_url=service_url,
537-
vector_db_id=vector_db_id,
538-
)
539-
convert_task.set_caching_options(False)
540-
convert_task.set_cpu_request("500m")
541-
convert_task.set_cpu_limit("4")
542-
convert_task.set_memory_request("2Gi")
543-
convert_task.set_memory_limit("6Gi")
544-
convert_task.set_accelerator_type("nvidia.com/gpu")
545-
convert_task.set_accelerator_limit(1)
546-
add_toleration_json(
547-
convert_task,
548-
[
549-
{
550-
"effect": "NoSchedule",
551-
"key": "nvidia.com/gpu",
552-
"operator": "Exists",
553-
}
554-
],
555-
)
556-
add_node_selector_json(convert_task, {})
557-
with dsl.Else():
558-
convert_task = docling_convert_and_ingest_audio(
559-
input_path=import_task.output,
560-
audio_split=audio_split,
561-
embed_model_id=embed_model_id,
562-
max_tokens=max_tokens,
563-
service_url=service_url,
564-
vector_db_id=vector_db_id,
565-
)
566-
convert_task.set_caching_options(False)
567-
convert_task.set_cpu_request("500m")
568-
convert_task.set_cpu_limit("4")
569-
convert_task.set_memory_request("2Gi")
570-
convert_task.set_memory_limit("6Gi")
583+
import_task = import_audio_files(
584+
base_url=base_url,
585+
audio_filenames=audio_filenames,
586+
)
587+
import_task.set_caching_options(True)
588+
589+
audio_splits = create_audio_splits(
590+
input_path=import_task.output,
591+
num_splits=num_workers,
592+
).set_caching_options(True)
593+
594+
with dsl.ParallelFor(audio_splits.output) as audio_split:
595+
with dsl.If(use_gpu == True):
596+
convert_task = docling_convert_and_ingest_audio(
597+
input_path=import_task.output,
598+
audio_split=audio_split,
599+
embed_model_id=embed_model_id,
600+
max_tokens=max_tokens,
601+
service_url=service_url,
602+
vector_db_id=vector_db_id,
603+
)
604+
convert_task.set_caching_options(False)
605+
convert_task.set_cpu_request("500m")
606+
convert_task.set_cpu_limit("4")
607+
convert_task.set_memory_request("2Gi")
608+
convert_task.set_memory_limit("6Gi")
609+
convert_task.set_accelerator_type("nvidia.com/gpu")
610+
convert_task.set_accelerator_limit(1)
611+
add_toleration_json(
612+
convert_task,
613+
[
614+
{
615+
"effect": "NoSchedule",
616+
"key": "nvidia.com/gpu",
617+
"operator": "Exists",
618+
}
619+
],
620+
)
621+
add_node_selector_json(convert_task, {})
622+
with dsl.Else():
623+
convert_task = docling_convert_and_ingest_audio(
624+
input_path=import_task.output,
625+
audio_split=audio_split,
626+
embed_model_id=embed_model_id,
627+
max_tokens=max_tokens,
628+
service_url=service_url,
629+
vector_db_id=vector_db_id,
630+
)
631+
convert_task.set_caching_options(False)
632+
convert_task.set_cpu_request("500m")
633+
convert_task.set_cpu_limit("4")
634+
convert_task.set_memory_request("2Gi")
635+
convert_task.set_memory_limit("6Gi")
571636

572637

573638
if __name__ == "__main__":

0 commit comments

Comments
 (0)