33
44import inspect
55import os
6- import shutil
76import traceback
87
8+ import graphrag .api as api
99import yaml
1010from fastapi import (
1111 APIRouter ,
1212 HTTPException ,
1313)
14- from fastapi .responses import StreamingResponse
15- from graphrag .prompt_tune .cli import prompt_tune as generate_fine_tune_prompts
14+ from graphrag .config .create_graphrag_config import create_graphrag_config
1615
1716from src .api .azure_clients import AzureClientManager
1817from src .api .common import (
2726
2827@index_configuration_route .get (
2928 "/prompts" ,
30- summary = "Generate graphrag prompts from user-provided data." ,
29+ summary = "Generate prompts from user-provided data." ,
3130 description = "Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used." ,
3231)
3332async def generate_prompts (storage_name : str , limit : int = 5 ):
@@ -44,29 +43,23 @@ async def generate_prompts(storage_name: str, limit: int = 5):
4443 status_code = 500 ,
4544 detail = f"Data container '{ storage_name } ' does not exist." ,
4645 )
46+
47+ # load the pipeline configuration file (pipeline-settings.yaml) for input data and other settings
4748 this_directory = os .path .dirname (
4849 os .path .abspath (inspect .getfile (inspect .currentframe ()))
4950 )
50-
51- # write custom settings.yaml to a file and store in a temporary directory
5251 data = yaml .safe_load (open (f"{ this_directory } /pipeline-settings.yaml" ))
5352 data ["input" ]["container_name" ] = sanitized_storage_name
54- temp_dir = f"/tmp/{ sanitized_storage_name } _prompt_tuning"
55- shutil .rmtree (temp_dir , ignore_errors = True )
56- os .makedirs (temp_dir , exist_ok = True )
57- with open (f"{ temp_dir } /settings.yaml" , "w" ) as f :
58- yaml .dump (data , f , default_flow_style = False )
53+ graphrag_config = create_graphrag_config (values = data , root_dir = "." )
5954
6055 # generate prompts
6156 try :
62- await generate_fine_tune_prompts (
63- config = f"{ temp_dir } /settings.yaml" ,
64- root = temp_dir ,
65- domain = "" ,
66- selection_method = "random" ,
57+ # NOTE: we need to call api.generate_indexing_prompts
58+ prompts : tuple [str , str , str ] = await api .generate_indexing_prompts (
59+ config = graphrag_config ,
60+ root = "." ,
6761 limit = limit ,
68- skip_entity_types = True ,
69- output = f"{ temp_dir } /prompts" ,
62+ selection_method = "random" ,
7063 )
7164 except Exception as e :
7265 logger = LoggerSingleton ().get_instance ()
@@ -84,14 +77,9 @@ async def generate_prompts(storage_name: str, limit: int = 5):
8477 detail = f"Error generating prompts for data in '{ storage_name } '. Please try a lower limit." ,
8578 )
8679
87- # zip up the generated prompt files and return the zip file
88- temp_archive = (
89- f"{ temp_dir } /prompts" # will become a zip file with the name prompts.zip
90- )
91- shutil .make_archive (temp_archive , "zip" , root_dir = temp_dir , base_dir = "prompts" )
92-
93- def iterfile (file_path : str ):
94- with open (file_path , mode = "rb" ) as file_like :
95- yield from file_like
96-
97- return StreamingResponse (iterfile (f"{ temp_archive } .zip" ))
80+ content = {
81+ "entity_extraction_prompt" : prompts [0 ],
82+ "entity_summarization_prompt" : prompts [1 ],
83+ "community_summarization_prompt" : prompts [2 ],
84+ }
85 return content # plain dict; FastAPI serializes it into the JSON response body
0 commit comments