Skip to content

Commit 8e6b61d

Browse files
committed
Update ai search
1 parent b3e3c04 commit 8e6b61d

File tree

5 files changed

+130
-108
lines changed

5 files changed

+130
-108
lines changed

deploy_ai_search/src/deploy_ai_search/ai_search.py

Lines changed: 114 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,10 @@ def get_text_split_skill(
322322

323323
return semantic_text_chunker_skill
324324

325-
def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
326-
"""Get the custom skill for adi.
325+
def get_layout_analysis_skill(
326+
self, chunk_by_page=False, extract_figures=True
327+
) -> WebApiSkill:
328+
"""Get the custom skill for layout analysis.
327329
328330
Args:
329331
-----
@@ -343,25 +345,24 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
343345

344346
if chunk_by_page:
345347
output = [
346-
OutputFieldMappingEntry(name="extracted_content", target_name="chunks")
348+
OutputFieldMappingEntry(name="layout", target_name="page_wise_layout")
347349
]
348350
else:
349-
output = [
350-
OutputFieldMappingEntry(
351-
name="extracted_content", target_name="extracted_content"
352-
)
353-
]
351+
output = [OutputFieldMappingEntry(name="layout", target_name="layout")]
354352

355-
adi_skill = WebApiSkill(
356-
name="ADI Skill",
353+
layout_analysis_skill = WebApiSkill(
354+
name="Layout Analysis Skill",
357355
description="Skill to generate ADI",
358356
context="/document",
359-
uri=self.environment.get_custom_skill_function_url("adi"),
357+
uri=self.environment.get_custom_skill_function_url("layout_analysis"),
360358
timeout="PT230S",
361359
batch_size=batch_size,
362360
degree_of_parallelism=degree_of_parallelism,
363361
http_method="POST",
364-
http_headers={"chunk_by_page": chunk_by_page},
362+
http_headers={
363+
"chunk_by_page": chunk_by_page,
364+
"extract_figures": extract_figures,
365+
},
365366
inputs=[
366367
InputFieldMappingEntry(
367368
name="source", source="/document/metadata_storage_path"
@@ -371,100 +372,150 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
371372
)
372373

373374
if self.environment.identity_type != IdentityType.KEY:
374-
adi_skill.auth_identity = (
375+
layout_analysis_skill.auth_identity = (
375376
self.environment.function_app_app_registration_resource_id
376377
)
377378

378379
if self.environment.identity_type == IdentityType.USER_ASSIGNED:
379-
adi_skill.auth_identity = self.environment.ai_search_user_assigned_identity
380+
layout_analysis_skill.auth_identity = (
381+
self.environment.ai_search_user_assigned_identity
382+
)
380383

381-
return adi_skill
384+
return layout_analysis_skill
382385

383-
def get_vector_skill(
384-
self, context, source, target_name="vector"
385-
) -> AzureOpenAIEmbeddingSkill:
386-
"""Get the vector skill for the indexer.
386+
def get_figure_analysis_skill(self, figure_source) -> WebApiSkill:
387+
"""Get the custom skill for figure analysis.
388+
389+
Args:
390+
-----
391+
figure_source: The enrichment path of the figure; used both as the skill context and as the source of the "figure" input.
387392
388393
Returns:
389-
AzureOpenAIEmbeddingSkill: The vector skill for the indexer"""
394+
--------
395+
WebApiSkill: The custom skill for figure analysis"""
390396

391-
embedding_skill_inputs = [
392-
InputFieldMappingEntry(name="text", source=source),
393-
]
394-
embedding_skill_outputs = [
395-
OutputFieldMappingEntry(name="embedding", target_name=target_name)
397+
if self.test:
398+
batch_size = 1
399+
degree_of_parallelism = 4
400+
else:
401+
# Depending on your GPT Token limit, you may need to adjust the batch size and degree of parallelism
402+
batch_size = 1
403+
degree_of_parallelism = 8
404+
405+
output = [
406+
OutputFieldMappingEntry(name="updated_figure", target_name="updated_figure")
396407
]
397408

398-
vector_skill = AzureOpenAIEmbeddingSkill(
399-
name="Vector Skill",
400-
description="Skill to generate embeddings",
401-
context=context,
402-
deployment_name=self.environment.open_ai_embedding_deployment,
403-
model_name=self.environment.open_ai_embedding_model,
404-
resource_url=self.environment.open_ai_endpoint,
405-
inputs=embedding_skill_inputs,
406-
outputs=embedding_skill_outputs,
407-
dimensions=self.environment.open_ai_embedding_dimensions,
409+
figure_analysis_skill = WebApiSkill(
410+
name="Figure Analysis Skill",
411+
description="Skill to generate figure analysis",
412+
context=figure_source,
413+
uri=self.environment.get_custom_skill_function_url("figure_analysis"),
414+
timeout="PT230S",
415+
batch_size=batch_size,
416+
degree_of_parallelism=degree_of_parallelism,
417+
http_method="POST",
418+
inputs=[InputFieldMappingEntry(name="figure", source=figure_source)],
419+
outputs=output,
408420
)
409421

410-
if self.environment.identity_type == IdentityType.KEY:
411-
vector_skill.api_key = self.environment.open_ai_api_key
412-
elif self.environment.identity_type == IdentityType.USER_ASSIGNED:
413-
vector_skill.auth_identity = (
422+
if self.environment.identity_type != IdentityType.KEY:
423+
figure_analysis_skill.auth_identity = (
424+
self.environment.function_app_app_registration_resource_id
425+
)
426+
427+
if self.environment.identity_type == IdentityType.USER_ASSIGNED:
428+
figure_analysis_skill.auth_identity = (
414429
self.environment.ai_search_user_assigned_identity
415430
)
416431

417-
return vector_skill
432+
return figure_analysis_skill
418433

419-
def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill:
420-
"""Get the key phrase extraction skill.
434+
def get_layout_and_figure_merger_skill(self, figure_source) -> WebApiSkill:
435+
"""Get the custom skill for layout and figure merger.
421436
422437
Args:
423438
-----
424-
context (str): The context of the skill
425-
source (str): The source of the skill
439+
figure_source: The enrichment path of the figure; used both as the skill context and as the source of the "figure" input.
426440
427441
Returns:
428442
--------
429-
WebApiSkill: The key phrase extraction skill"""
443+
WebApiSkill: The custom skill for layout and figure merger"""
430444

431445
if self.test:
432-
batch_size = 4
446+
batch_size = 1
433447
degree_of_parallelism = 4
434448
else:
435-
batch_size = 16
436-
degree_of_parallelism = 16
449+
# Depending on your GPT Token limit, you may need to adjust the batch size and degree of parallelism
450+
batch_size = 1
451+
degree_of_parallelism = 8
437452

438-
key_phrase_extraction_skill_inputs = [
439-
InputFieldMappingEntry(name="text", source=source),
440-
]
441-
key_phrase_extraction__skill_outputs = [
442-
OutputFieldMappingEntry(name="key_phrases", target_name="keywords")
453+
output = [
454+
OutputFieldMappingEntry(name="updated_figure", target_name="updated_figure")
443455
]
444-
key_phrase_extraction_skill = WebApiSkill(
445-
name="Key phrase extraction API",
446-
description="Skill to extract keyphrases",
447-
context=context,
448-
uri=self.environment.get_custom_skill_function_url("key_phrase_extraction"),
456+
457+
figure_analysis_skill = WebApiSkill(
458+
name="Layout and Figure Merger Skill",
459+
description="Skill to merge layout and figure analysis",
460+
context=figure_source,
461+
uri=self.environment.get_custom_skill_function_url(
462+
"layout_and_figure_merger"
463+
),
449464
timeout="PT230S",
450465
batch_size=batch_size,
451466
degree_of_parallelism=degree_of_parallelism,
452467
http_method="POST",
453-
inputs=key_phrase_extraction_skill_inputs,
454-
outputs=key_phrase_extraction__skill_outputs,
468+
inputs=[InputFieldMappingEntry(name="figure", source=figure_source)],
469+
outputs=output,
455470
)
456471

457472
if self.environment.identity_type != IdentityType.KEY:
458-
key_phrase_extraction_skill.auth_identity = (
473+
figure_analysis_skill.auth_identity = (
459474
self.environment.function_app_app_registration_resource_id
460475
)
461476

462477
if self.environment.identity_type == IdentityType.USER_ASSIGNED:
463-
key_phrase_extraction_skill.auth_identity = (
478+
figure_analysis_skill.auth_identity = (
479+
self.environment.ai_search_user_assigned_identity
480+
)
481+
482+
return figure_analysis_skill
483+
484+
def get_vector_skill(
485+
self, context, source, target_name="vector"
486+
) -> AzureOpenAIEmbeddingSkill:
487+
"""Get the vector skill for the indexer.
488+
489+
Returns:
490+
AzureOpenAIEmbeddingSkill: The vector skill for the indexer"""
491+
492+
embedding_skill_inputs = [
493+
InputFieldMappingEntry(name="text", source=source),
494+
]
495+
embedding_skill_outputs = [
496+
OutputFieldMappingEntry(name="embedding", target_name=target_name)
497+
]
498+
499+
vector_skill = AzureOpenAIEmbeddingSkill(
500+
name="Vector Skill",
501+
description="Skill to generate embeddings",
502+
context=context,
503+
deployment_name=self.environment.open_ai_embedding_deployment,
504+
model_name=self.environment.open_ai_embedding_model,
505+
resource_url=self.environment.open_ai_endpoint,
506+
inputs=embedding_skill_inputs,
507+
outputs=embedding_skill_outputs,
508+
dimensions=self.environment.open_ai_embedding_dimensions,
509+
)
510+
511+
if self.environment.identity_type == IdentityType.KEY:
512+
vector_skill.api_key = self.environment.open_ai_api_key
513+
elif self.environment.identity_type == IdentityType.USER_ASSIGNED:
514+
vector_skill.auth_identity = (
464515
self.environment.ai_search_user_assigned_identity
465516
)
466517

467-
return key_phrase_extraction_skill
518+
return vector_skill
468519

469520
def get_vector_search(self) -> VectorSearch:
470521
"""Get the vector search configuration for compass.

deploy_ai_search/src/deploy_ai_search/rag_documents.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def get_index_fields(self) -> list[SearchableField]:
9898
facetable=True,
9999
),
100100
ComplexField(
101-
name="Figures",
101+
name="ChunkFigures",
102102
collection=True,
103103
fields=[
104104
SearchableField(
@@ -107,31 +107,11 @@ def get_index_fields(self) -> list[SearchableField]:
107107
collection=True,
108108
searchable=False,
109109
),
110-
SimpleField(
111-
name="Container",
112-
type=SearchFieldDataType.String,
113-
filterable=True,
114-
),
115-
SimpleField(
116-
name="ImageBlob",
117-
type=SearchFieldDataType.String,
118-
filterable=True,
119-
),
120110
SimpleField(
121111
name="Caption",
122112
type=SearchFieldDataType.String,
123113
filterable=True,
124114
),
125-
SimpleField(
126-
name="Offset",
127-
type=SearchFieldDataType.Int64,
128-
filterable=True,
129-
),
130-
SimpleField(
131-
name="Length",
132-
type=SearchFieldDataType.Int64,
133-
filterable=True,
134-
),
135115
SimpleField(
136116
name="PageNumber",
137117
type=SearchFieldDataType.Int64,
@@ -258,16 +238,7 @@ def get_index_projections(self) -> SearchIndexerIndexProjection:
258238
),
259239
InputFieldMappingEntry(
260240
name="Figures",
261-
source_context="/document/chunks/*/figures/*",
262-
inputs=[
263-
InputFieldMappingEntry(
264-
name="FigureId", source="/document/chunks/*/figures/*/figure_id"
265-
),
266-
InputFieldMappingEntry(
267-
name="FigureUri",
268-
source="/document/chunks/*/figures/*/figure_uri",
269-
),
270-
],
241+
source_context="/document/chunks/*/chunk_figures/*",
271242
),
272243
InputFieldMappingEntry(
273244
name="DateLastModified", source="/document/DateLastModified"

image_processing/src/image_processing/layout_analysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ async def process_figures_from_extracted_content(
247247
)
248248
)
249249

250-
image_blob = f"{self.blob}/{figure.id}.png"
250+
blob = f"{self.blob}/{figure.id}.png"
251251

252252
caption = (
253253
figure.caption.content if figure.caption is not None else None
@@ -257,15 +257,15 @@ async def process_figures_from_extracted_content(
257257
uri = "{}/{}/{}".format(
258258
storage_account_helper.account_url,
259259
self.images_container,
260-
image_blob,
260+
blob,
261261
)
262262

263263
offset = figure.spans[0].offset - text_holder.page_offsets
264264

265265
image_processing_data = FigureHolder(
266266
figure_id=figure.id,
267267
container=self.images_container,
268-
image_blob=image_blob,
268+
blob=blob,
269269
caption=caption,
270270
offset=offset,
271271
length=figure.spans[0].length,
@@ -293,7 +293,7 @@ async def process_figures_from_extracted_content(
293293
figure_upload_tasks.append(
294294
storage_account_helper.upload_blob(
295295
figure_processing_data.container,
296-
figure_processing_data.image_blob,
296+
figure_processing_data.blob,
297297
image_data,
298298
"image/png",
299299
)

image_processing/src/image_processing/layout_holders.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@ class FigureHolder(BaseModel):
99

1010
"""A class to hold the figure extracted from the document."""
1111

12-
figure_id: str = Field(alias="FigureId")
13-
container: str = Field(default="Container")
14-
image_blob: str = Field(default="ImageBlob")
15-
caption: Optional[str] = Field(default=None, alias="Caption")
16-
offset: int = Field(alias="Offset")
17-
length: int = Field(alias="Length")
18-
page_number: Optional[int] = Field(default=None, alias="PageNumber")
19-
uri: str = Field(alias="Uri")
20-
description: Optional[str] = Field(default="", alias="Description")
21-
data: Optional[str] = Field(default=None, alias="Data")
12+
figure_id: str
13+
container: str = Field(exclude=True)
14+
blob: str = Field(exclude=True)
15+
caption: Optional[str] = Field(default=None)
16+
offset: int
17+
length: int
18+
page_number: Optional[int] = Field(default=None)
19+
uri: str
20+
description: Optional[str] = Field(default="")
21+
data: Optional[str] = Field(default=None)
2222

2323
@property
2424
def markdown(self) -> str:

image_processing/src/image_processing/mark_down_cleaner.py renamed to image_processing/src/image_processing/mark_up_cleaner.py

File renamed without changes.

0 commit comments

Comments
 (0)