
Commit 0252646

refactor variable names to be more generic and add integration tests

1 parent: ff5714a

File tree

3 files changed (+55, −32 lines):
- backend/src/api/index_configuration.py
- backend/tests/integration/test_utils_pipeline.py
- infra/deploy.sh

backend/src/api/index_configuration.py

Lines changed: 17 additions & 29 deletions
```diff
@@ -3,16 +3,15 @@
 
 import inspect
 import os
-import shutil
 import traceback
 
+import graphrag.api as api
 import yaml
 from fastapi import (
     APIRouter,
     HTTPException,
 )
-from fastapi.responses import StreamingResponse
-from graphrag.prompt_tune.cli import prompt_tune as generate_fine_tune_prompts
+from graphrag.config.create_graphrag_config import create_graphrag_config
 
 from src.api.azure_clients import AzureClientManager
 from src.api.common import (
@@ -27,7 +26,7 @@
 
 @index_configuration_route.get(
     "/prompts",
-    summary="Generate graphrag prompts from user-provided data.",
+    summary="Generate prompts from user-provided data.",
     description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
 )
 async def generate_prompts(storage_name: str, limit: int = 5):
@@ -44,29 +43,23 @@ async def generate_prompts(storage_name: str, limit: int = 5):
             status_code=500,
             detail=f"Data container '{storage_name}' does not exist.",
         )
+
+    # load pipeline configuration file (settings.yaml) for input data and other settings
     this_directory = os.path.dirname(
         os.path.abspath(inspect.getfile(inspect.currentframe()))
     )
-
-    # write custom settings.yaml to a file and store in a temporary directory
     data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
     data["input"]["container_name"] = sanitized_storage_name
-    temp_dir = f"/tmp/{sanitized_storage_name}_prompt_tuning"
-    shutil.rmtree(temp_dir, ignore_errors=True)
-    os.makedirs(temp_dir, exist_ok=True)
-    with open(f"{temp_dir}/settings.yaml", "w") as f:
-        yaml.dump(data, f, default_flow_style=False)
+    graphrag_config = create_graphrag_config(values=data, root_dir=".")
 
     # generate prompts
     try:
-        await generate_fine_tune_prompts(
-            config=f"{temp_dir}/settings.yaml",
-            root=temp_dir,
-            domain="",
-            selection_method="random",
+        # NOTE: we need to call api.generate_indexing_prompts
+        prompts: tuple[str, str, str] = await api.generate_indexing_prompts(
+            config=graphrag_config,
+            root=".",
             limit=limit,
-            skip_entity_types=True,
-            output=f"{temp_dir}/prompts",
+            selection_method="random",
         )
     except Exception as e:
         logger = LoggerSingleton().get_instance()
@@ -84,14 +77,9 @@ async def generate_prompts(storage_name: str, limit: int = 5):
             detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
         )
 
-    # zip up the generated prompt files and return the zip file
-    temp_archive = (
-        f"{temp_dir}/prompts"  # will become a zip file with the name prompts.zip
-    )
-    shutil.make_archive(temp_archive, "zip", root_dir=temp_dir, base_dir="prompts")
-
-    def iterfile(file_path: str):
-        with open(file_path, mode="rb") as file_like:
-            yield from file_like
-
-    return StreamingResponse(iterfile(f"{temp_archive}.zip"))
+    content = {
+        "entity_extraction_prompt": prompts[0],
+        "entity_summarization_prompt": prompts[1],
+        "community_summarization_prompt": prompts[2],
+    }
+    return content  # return a fastapi.responses.JSONResponse object
```
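Because the endpoint now returns plain JSON rather than streaming a `prompts.zip` archive, callers read the three prompt strings directly from the response body instead of downloading and unpacking a file. A minimal sketch of the updated client-side handling, assuming a hypothetical gateway URL and the `requests` library (neither is part of this commit):

```python
import requests

# hypothetical APIM gateway URL and query parameters
url = "https://my-apim-gateway.azure-api.net/index/config/prompts"
params = {"storage_name": "my-data-container", "limit": 5}

resp = requests.get(url, params=params, timeout=600)
resp.raise_for_status()

# the endpoint now returns a JSON object containing three prompt strings,
# replacing the old StreamingResponse over a zipped prompts folder
prompts = resp.json()
entity_extraction = prompts["entity_extraction_prompt"]
entity_summarization = prompts["entity_summarization_prompt"]
community_summarization = prompts["community_summarization_prompt"]
```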

backend/tests/integration/test_utils_pipeline.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -42,7 +42,9 @@ def cosmos_index_job_entry(cosmos_client) -> Generator[str, None, None]:
 
 
 def test_pipeline_job_interface(cosmos_index_job_entry):
+    """Test the src.utils.pipeline.PipelineJob class interface."""
     pipeline_job = PipelineJob()
+
     # test creating a new entry
     pipeline_job.create_item(
         id="synthetic_id",
@@ -69,3 +71,36 @@ def test_pipeline_job_interface(cosmos_index_job_entry):
     assert pipeline_job.status == PipelineJobState.COMPLETE
     assert pipeline_job.percent_complete == 50.0
     assert pipeline_job.progress == "some progress"
+    assert pipeline_job.calculate_percent_complete() == 50.0
+
+    # test setters and getters
+    pipeline_job.id = "newID"
+    assert pipeline_job.id == "newID"
+    pipeline_job.epoch_request_time = 1
+    assert pipeline_job.epoch_request_time == 1
+
+    pipeline_job.human_readable_index_name = "new_human_readable_index_name"
+    assert pipeline_job.human_readable_index_name == "new_human_readable_index_name"
+    pipeline_job.sanitized_index_name = "new_sanitized_index_name"
+    assert pipeline_job.sanitized_index_name == "new_sanitized_index_name"
+
+    pipeline_job.human_readable_storage_name = "new_human_readable_storage_name"
+    assert pipeline_job.human_readable_storage_name == "new_human_readable_storage_name"
+    pipeline_job.sanitized_storage_name = "new_sanitized_storage_name"
+    assert pipeline_job.sanitized_storage_name == "new_sanitized_storage_name"
+
+    pipeline_job.entity_extraction_prompt = "new_entity_extraction_prompt"
+    assert pipeline_job.entity_extraction_prompt == "new_entity_extraction_prompt"
+    pipeline_job.community_report_prompt = "new_community_report_prompt"
+    assert pipeline_job.community_report_prompt == "new_community_report_prompt"
+    pipeline_job.summarize_descriptions_prompt = "new_summarize_descriptions_prompt"
+    assert pipeline_job.summarize_descriptions_prompt == "new_summarize_descriptions_prompt"
+
+    pipeline_job.all_workflows = ["new_workflow1", "new_workflow2", "new_workflow3"]
+    assert len(pipeline_job.all_workflows) == 3
+
+    pipeline_job.completed_workflows = ["new_workflow1", "new_workflow2"]
+    assert len(pipeline_job.completed_workflows) == 2
+
+    pipeline_job.failed_workflows = ["new_workflow3"]
+    assert len(pipeline_job.failed_workflows) == 1
```
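The new assertions rely on `PipelineJob` exposing plain attribute-style getters and setters. A minimal sketch of the pattern one such property pair might follow, assuming the class mirrors writes into a backing item dict that is later upserted to the Cosmos DB jobs container (the `_item` field and `_sync` helper here are hypothetical illustrations, not taken from this commit):

```python
class PipelineJob:
    """Illustrative sketch of the getter/setter pattern the test exercises."""

    def __init__(self) -> None:
        # hypothetical backing store for the Cosmos DB item
        self._item: dict = {}

    @property
    def human_readable_index_name(self) -> str:
        return self._item.get("human_readable_index_name", "")

    @human_readable_index_name.setter
    def human_readable_index_name(self, value: str) -> None:
        self._item["human_readable_index_name"] = value
        self._sync()  # hypothetical: persist the updated item back to Cosmos DB

    def _sync(self) -> None:
        # placeholder; the real class would upsert into the jobs container
        pass
```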

infra/deploy.sh

Lines changed: 3 additions & 3 deletions
```diff
@@ -347,12 +347,12 @@ deployAzureResources () {
     --resource-group $RESOURCE_GROUP \
     --template-file ./main.bicep \
     --parameters "resourceBaseName=$RESOURCE_BASE_NAME" \
-    --parameters "graphRagName=$RESOURCE_GROUP" \
+    --parameters "resourceGroupName=$RESOURCE_GROUP" \
    --parameters "apimName=$APIM_NAME" \
     --parameters "apimTier=$APIM_TIER" \
-    --parameters "publisherName=$PUBLISHER_NAME" \
+    --parameters "apiPublisherName=$PUBLISHER_NAME" \
+    --parameters "apiPublisherEmail=$PUBLISHER_EMAIL" \
     --parameters "aksSshRsaPublicKey=$SSH_PUBLICKEY" \
-    --parameters "publisherEmail=$PUBLISHER_EMAIL" \
     --parameters "enablePrivateEndpoints=$ENABLE_PRIVATE_ENDPOINTS" \
     --parameters "acrName=$CONTAINER_REGISTRY_NAME" \
     --parameters "deployerPrincipalId=$deployerPrincipalId" \
```
