Skip to content
This repository was archived by the owner on May 27, 2025. It is now read-only.

Commit 84f2770

Browse files
committed
update how prompts get saved in cosmosdb
1 parent 72d759a commit 84f2770

File tree

7 files changed

+95
-75
lines changed

backend/src/api/index.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ async def schedule_indexing_job(
4949
storage_name: str,
5050
index_name: str,
5151
entity_extraction_prompt: UploadFile | None = None,
52-
community_report_prompt: UploadFile | None = None,
53-
summarize_descriptions_prompt: UploadFile | None = None,
52+
entity_summarization_prompt: UploadFile | None = None,
53+
community_summarization_prompt: UploadFile | None = None,
5454
):
5555
azure_client_manager = AzureClientManager()
5656
blob_service_client = azure_client_manager.get_blob_service_client()
@@ -80,14 +80,14 @@ async def schedule_indexing_job(
8080
if entity_extraction_prompt
8181
else None
8282
)
83-
community_report_prompt_content = (
84-
community_report_prompt.file.read().decode("utf-8")
85-
if community_report_prompt
83+
entity_summarization_prompt_content = (
84+
entity_summarization_prompt.file.read().decode("utf-8")
85+
if entity_summarization_prompt
8686
else None
8787
)
88-
summarize_descriptions_prompt_content = (
89-
summarize_descriptions_prompt.file.read().decode("utf-8")
90-
if summarize_descriptions_prompt
88+
community_summarization_prompt_content = (
89+
community_summarization_prompt.file.read().decode("utf-8")
90+
if community_summarization_prompt
9191
else None
9292
)
9393

@@ -116,9 +116,9 @@ async def schedule_indexing_job(
116116
existing_job._failed_workflows
117117
) = []
118118
existing_job._entity_extraction_prompt = entity_extraction_prompt_content
119-
existing_job._community_report_prompt = community_report_prompt_content
120-
existing_job._summarize_descriptions_prompt = (
121-
summarize_descriptions_prompt_content
119+
existing_job._entity_summarization_prompt = entity_summarization_prompt_content
120+
existing_job._community_summarization_prompt = (
121+
community_summarization_prompt_content
122122
)
123123
existing_job._epoch_request_time = int(time())
124124
existing_job.update_db()
@@ -128,8 +128,8 @@ async def schedule_indexing_job(
128128
human_readable_index_name=index_name,
129129
human_readable_storage_name=storage_name,
130130
entity_extraction_prompt=entity_extraction_prompt_content,
131-
community_report_prompt=community_report_prompt_content,
132-
summarize_descriptions_prompt=summarize_descriptions_prompt_content,
131+
entity_summarization_prompt=entity_summarization_prompt_content,
132+
community_summarization_prompt=community_summarization_prompt_content,
133133
status=PipelineJobState.SCHEDULED,
134134
)
135135

backend/src/indexer/__init__.py

Whitespace-only changes.

backend/src/indexer/indexer.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@
1212
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1313
from graphrag.config.create_graphrag_config import create_graphrag_config
1414
from graphrag.index.create_pipeline_config import create_pipeline_config
15+
from graphrag.index.typing import PipelineRunResult
1516

16-
from src.logger import (
17+
from ...src.logger import (
1718
PipelineJobUpdater,
1819
load_pipeline_logger,
1920
)
20-
from src.typing.pipeline import PipelineJobState
21-
from src.utils.azure_clients import AzureClientManager
22-
from src.utils.common import sanitize_name
23-
from src.utils.pipeline import PipelineJob
21+
from ...src.typing.pipeline import PipelineJobState
22+
from ...src.utils.azure_clients import AzureClientManager
23+
from ...src.utils.common import sanitize_name
24+
from ...src.utils.pipeline import PipelineJob
2425

2526

2627
def start_indexing_job(index_name: str):
28+
return 0
2729
print("Start indexing job...")
2830
# get sanitized name
2931
sanitized_index_name = sanitize_name(index_name)
@@ -73,20 +75,20 @@ def start_indexing_job(index_name: str):
7375
else:
7476
data.pop("entity_extraction")
7577

76-
# set prompt for summarize descriptions
77-
if pipeline_job.summarize_descriptions_prompt:
78-
fname = "summarize-descriptions-prompt.txt"
78+
# set prompt for entity summarization
79+
if pipeline_job.entity_summarization_prompt:
80+
fname = "entity-summarization-prompt.txt"
7981
with open(fname, "w") as outfile:
80-
outfile.write(pipeline_job.summarize_descriptions_prompt)
82+
outfile.write(pipeline_job.entity_summarization_prompt)
8183
data["summarize_descriptions"]["prompt"] = fname
8284
else:
8385
data.pop("summarize_descriptions")
8486

85-
# set prompt for community report
86-
if pipeline_job.community_report_prompt:
87-
fname = "community-report-prompt.txt"
87+
# set prompt for community summarization
88+
if pipeline_job.community_summarization_prompt:
89+
fname = "community-summarization-prompt.txt"
8890
with open(fname, "w") as outfile:
89-
outfile.write(pipeline_job.community_report_prompt)
91+
outfile.write(pipeline_job.community_summarization_prompt)
9092
data["community_reports"]["prompt"] = fname
9193
else:
9294
data.pop("community_reports")
@@ -101,7 +103,7 @@ def start_indexing_job(index_name: str):
101103
pipeline_job.failed_workflows = []
102104
pipeline_config = create_pipeline_config(parameters)
103105
for workflow in pipeline_config.workflows:
104-
pipeline_job.all_workflows.append(workflow.name)
106+
pipeline_job.all_workflows = pipeline_job.all_workflows.append(workflow.name)
105107

106108
# create new loggers/callbacks just for this job
107109
print("Creating generic loggers...")
@@ -118,16 +120,27 @@ def start_indexing_job(index_name: str):
118120
# run the pipeline
119121
try:
120122
print("Building index...")
121-
asyncio.run(
123+
pipeline_results: list[PipelineRunResult] = asyncio.run(
122124
api.build_index(
123125
config=parameters,
124126
callbacks=[logger, pipeline_job_updater],
125127
)
126128
)
127-
print("Index building complete")
128-
# if job is done, check if any pipeline steps failed
129+
130+
# once indexing job is done, check if any pipeline steps failed
131+
for result in pipeline_results:
132+
if result.errors:
133+
pipeline_job.failed_workflows = pipeline_job.failed_workflows.append(
134+
result.workflow
135+
)
136+
else:
137+
pipeline_job.completed_workflows = (
138+
pipeline_job.completed_workflows.append(result.workflow)
139+
)
140+
print("Indexing complete")
141+
129142
if len(pipeline_job.failed_workflows) > 0:
130-
print("Indexing pipeline encountered error.")
143+
print("Indexing pipeline encountered errors.")
131144
pipeline_job.status = PipelineJobState.FAILED
132145
logger.error(
133146
message=f"Indexing pipeline encountered error for index'{index_name}'.",
@@ -158,12 +171,10 @@ def start_indexing_job(index_name: str):
158171
exit(1) # signal to AKS that indexing job failed
159172
except Exception as e:
160173
pipeline_job.status = PipelineJobState.FAILED
161-
# update failed state in cosmos db
162174
error_details = {
163175
"index": index_name,
164176
"storage_name": storage_name,
165177
}
166-
# log error in local index directory logs
167178
logger.error(
168179
message=f"Indexing pipeline failed for index '{index_name}'.",
169180
cause=e,
@@ -177,8 +188,4 @@ def start_indexing_job(index_name: str):
177188
parser.add_argument("-i", "--index-name", required=True)
178189
args = parser.parse_args()
179190

180-
asyncio.run(
181-
start_indexing_job(
182-
index_name=args.index_name,
183-
)
184-
)
191+
start_indexing_job(index_name=args.index_name)

backend/src/indexer/settings.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ llm:
1313
api_version: $GRAPHRAG_API_VERSION
1414
model: $GRAPHRAG_LLM_MODEL
1515
deployment_name: $GRAPHRAG_LLM_DEPLOYMENT_NAME
16-
cognitive_services_endpoint: $GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT
16+
cognitive_services_endpoint: $COGNITIVE_SERVICES_AUDIENCE
1717
model_supports_json: True
1818
tokens_per_minute: 80_000
1919
requests_per_minute: 480
20-
concurrent_requests: 50
20+
concurrent_requests: 25
2121
max_retries: 250
2222
max_retry_wait: 60.0
2323
sleep_on_rate_limit_recommendation: True
@@ -43,7 +43,7 @@ embeddings:
4343
batch_size: 10
4444
model: $GRAPHRAG_EMBEDDING_MODEL
4545
deployment_name: $GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME
46-
cognitive_services_endpoint: $GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT
46+
cognitive_services_endpoint: $COGNITIVE_SERVICES_AUDIENCE
4747
tokens_per_minute: 350_000
4848
requests_per_minute: 2_100
4949

backend/src/logger/pipeline_job_updater.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@ def __init__(self, pipeline_job: PipelineJob):
2121
"""
2222
self._pipeline_job = pipeline_job
2323

24-
def on_workflow_start(self, name: str, instance: object) -> None:
24+
def workflow_start(self, name: str, instance: object) -> None:
2525
"""Execute this callback when a workflow starts."""
26-
# if we are not already running, set the status to running
27-
if self._pipeline_job.status != PipelineJobState.RUNNING:
28-
self._pipeline_job.status = PipelineJobState.RUNNING
26+
self._pipeline_job.status = PipelineJobState.RUNNING
2927
self._pipeline_job.progress = f"Workflow {name} started."
3028

31-
def on_workflow_end(self, name: str, instance: object) -> None:
29+
def workflow_end(self, name: str, instance: object) -> None:
3230
"""Execute this callback when a workflow ends."""
3331
self._pipeline_job.completed_workflows.append(name)
3432
self._pipeline_job.update_db()

backend/src/utils/pipeline.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,19 @@ class PipelineJob:
2727
_sanitized_index_name: str = field(default=None, init=False)
2828
_human_readable_storage_name: str = field(default=None, init=False)
2929
_sanitized_storage_name: str = field(default=None, init=False)
30-
_entity_extraction_prompt: str = field(default=None, init=False)
31-
_community_report_prompt: str = field(default=None, init=False)
32-
_summarize_descriptions_prompt: str = field(default=None, init=False)
30+
3331
_all_workflows: List[str] = field(default_factory=list, init=False)
3432
_completed_workflows: List[str] = field(default_factory=list, init=False)
3533
_failed_workflows: List[str] = field(default_factory=list, init=False)
34+
3635
_status: PipelineJobState = field(default=None, init=False)
3736
_percent_complete: float = field(default=0, init=False)
3837
_progress: str = field(default="", init=False)
3938

39+
_entity_extraction_prompt: str = field(default=None, init=False)
40+
_entity_summarization_prompt: str = field(default=None, init=False)
41+
_community_summarization_prompt: str = field(default=None, init=False)
42+
4043
@staticmethod
4144
def _jobs_container():
4245
azure_storage_client = AzureClientManager()
@@ -51,8 +54,8 @@ def create_item(
5154
human_readable_index_name: str,
5255
human_readable_storage_name: str,
5356
entity_extraction_prompt: str | None = None,
54-
community_report_prompt: str | None = None,
55-
summarize_descriptions_prompt: str | None = None,
57+
entity_summarization_prompt: str | None = None,
58+
community_summarization_prompt: str | None = None,
5659
**kwargs,
5760
) -> "PipelineJob":
5861
"""
@@ -95,18 +98,21 @@ def create_item(
9598
instance._sanitized_index_name = sanitize_name(human_readable_index_name)
9699
instance._human_readable_storage_name = human_readable_storage_name
97100
instance._sanitized_storage_name = sanitize_name(human_readable_storage_name)
98-
instance._entity_extraction_prompt = entity_extraction_prompt
99-
instance._community_report_prompt = community_report_prompt
100-
instance._summarize_descriptions_prompt = summarize_descriptions_prompt
101+
101102
instance._all_workflows = kwargs.get("all_workflows", [])
102103
instance._completed_workflows = kwargs.get("completed_workflows", [])
103104
instance._failed_workflows = kwargs.get("failed_workflows", [])
105+
104106
instance._status = PipelineJobState(
105107
kwargs.get("status", PipelineJobState.SCHEDULED.value)
106108
)
107109
instance._percent_complete = kwargs.get("percent_complete", 0.0)
108110
instance._progress = kwargs.get("progress", "")
109111

112+
instance._entity_extraction_prompt = entity_extraction_prompt
113+
instance._entity_summarization_prompt = entity_summarization_prompt
114+
instance._community_summarization_prompt = community_summarization_prompt
115+
110116
# Create the item in the database
111117
instance.update_db()
112118
return instance
@@ -140,17 +146,22 @@ def load_item(cls, id: str) -> "PipelineJob":
140146
"human_readable_storage_name"
141147
)
142148
instance._sanitized_storage_name = db_item.get("sanitized_storage_name")
143-
instance._entity_extraction_prompt = db_item.get("entity_extraction_prompt")
144-
instance._community_report_prompt = db_item.get("community_report_prompt")
145-
instance._summarize_descriptions_prompt = db_item.get(
146-
"summarize_descriptions_prompt"
147-
)
149+
148150
instance._all_workflows = db_item.get("all_workflows", [])
149151
instance._completed_workflows = db_item.get("completed_workflows", [])
150152
instance._failed_workflows = db_item.get("failed_workflows", [])
153+
151154
instance._status = PipelineJobState(db_item.get("status"))
152155
instance._percent_complete = db_item.get("percent_complete", 0.0)
153156
instance._progress = db_item.get("progress", "")
157+
158+
instance._entity_extraction_prompt = db_item.get("entity_extraction_prompt")
159+
instance._entity_summarization_prompt = db_item.get(
160+
"entity_summarization_prompt"
161+
)
162+
instance._community_summarization_prompt = db_item.get(
163+
"community_summarization_prompt"
164+
)
154165
return instance
155166

156167
@staticmethod
@@ -191,10 +202,12 @@ def dump_model(self) -> dict:
191202
}
192203
if self._entity_extraction_prompt:
193204
model["entity_extraction_prompt"] = self._entity_extraction_prompt
194-
if self._community_report_prompt:
195-
model["community_report_prompt"] = self._community_report_prompt
196-
if self._summarize_descriptions_prompt:
197-
model["summarize_descriptions_prompt"] = self._summarize_descriptions_prompt
205+
if self._entity_summarization_prompt:
206+
model["entity_summarization_prompt"] = self._entity_summarization_prompt
207+
if self._community_summarization_prompt:
208+
model["community_summarization_prompt"] = (
209+
self._community_summarization_prompt
210+
)
198211
return model
199212

200213
def update_db(self):
@@ -268,21 +281,23 @@ def entity_extraction_prompt(self, entity_extraction_prompt: str) -> None:
268281
self.update_db()
269282

270283
@property
271-
def community_report_prompt(self) -> str:
272-
return self._community_report_prompt
284+
def entity_summarization_prompt(self) -> str:
285+
return self._entity_summarization_prompt
273286

274-
@community_report_prompt.setter
275-
def community_report_prompt(self, community_report_prompt: str) -> None:
276-
self._community_report_prompt = community_report_prompt
287+
@entity_summarization_prompt.setter
288+
def entity_summarization_prompt(self, entity_summarization_prompt: str) -> None:
289+
self._entity_summarization_prompt = entity_summarization_prompt
277290
self.update_db()
278291

279292
@property
280-
def summarize_descriptions_prompt(self) -> str:
281-
return self._summarize_descriptions_prompt
282-
283-
@summarize_descriptions_prompt.setter
284-
def summarize_descriptions_prompt(self, summarize_descriptions_prompt: str) -> None:
285-
self._summarize_descriptions_prompt = summarize_descriptions_prompt
293+
def community_summarization_prompt(self) -> str:
294+
return self._community_summarization_prompt
295+
296+
@community_summarization_prompt.setter
297+
def community_summarization_prompt(
298+
self, community_summarization_prompt: str
299+
) -> None:
300+
self._community_summarization_prompt = community_summarization_prompt
286301
self.update_db()
287302

288303
@property

infra/helm/graphrag/values.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ graphragConfig:
3838
COSMOS_URI_ENDPOINT: ""
3939
GRAPHRAG_API_BASE: ""
4040
GRAPHRAG_API_VERSION: ""
41-
GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT: "https://cognitiveservices.azure.com/.default"
41+
COGNITIVE_SERVICES_AUDIENCE: "https://cognitiveservices.azure.com/.default"
4242
GRAPHRAG_LLM_MODEL: ""
4343
GRAPHRAG_LLM_DEPLOYMENT_NAME: ""
4444
GRAPHRAG_EMBEDDING_MODEL: ""

0 commit comments

Comments
 (0)