
Commit 20a96cb

Init command asks for models (#2137)
* Add init prompting for models
* Remove hard-coded model config validation
* Switch to typer option prompt for full CLI use with models
* Update getting started for init model input
* Bump request timeout and overall smoke test timeout
1 parent 7bf82b7 commit 20a96cb

File tree

11 files changed: +35 −61 lines


docs/get_started.md

Lines changed: 2 additions & 0 deletions
````diff
@@ -47,6 +47,8 @@ To initialize your workspace, first run the `graphrag init` command.
 graphrag init
 ```
 
+When prompted, specify the default chat and embedding models you would like to use in your config.
+
 This will create two files, `.env` and `settings.yaml`, and a directory `input`, in the current directory.
 
 - `input` Location of text files to process with `graphrag`.
````
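The added sentence tells users they will be asked for the two models during `graphrag init`. As a rough illustration only, the prompts can also be exercised non-interactively; this sketch assumes the typer app is exposed as `graphrag.cli.main.app`, that the command is named `init`, and that `--root` is the flag shown in the CLI changes further down — the model flags `--model` and `--embedding` come from this commit.

```python
# Sketch under assumptions: app import path, "init" command name, and --root flag
# are inferred, not confirmed by this diff; model names are placeholders.
from typer.testing import CliRunner

from graphrag.cli.main import app  # assumed CLI entry point

runner = CliRunner()

# Answer the two prompts via stdin, one value per line.
result = runner.invoke(app, ["init", "--root", "."], input="my-chat-model\nmy-embedding-model\n")
print(result.output)

# Or skip the prompts entirely by passing the new flags added in this commit.
result = runner.invoke(
    app,
    ["init", "--root", ".", "--model", "my-chat-model", "--embedding", "my-embedding-model"],
)
```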

packages/graphrag/graphrag/api/index.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -30,7 +30,6 @@ async def build_index(
     config: GraphRagConfig,
     method: IndexingMethod | str = IndexingMethod.Standard,
     is_update_run: bool = False,
-    memory_profile: bool = False,
     callbacks: list[WorkflowCallbacks] | None = None,
     additional_context: dict[str, Any] | None = None,
     verbose: bool = False,
@@ -67,9 +66,6 @@ async def build_index(
 
     outputs: list[PipelineRunResult] = []
 
-    if memory_profile:
-        logger.warning("New pipeline does not yet support memory profiling.")
-
     logger.info("Initializing indexing pipeline...")
     # todo: this could propagate out to the cli for better clarity, but will be a breaking api change
     method = _get_method(method, is_update_run)
```
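With `memory_profile` removed, callers pass only the remaining keywords. A minimal sketch of the updated call, assuming the import path follows the file location and eliding how the `GraphRagConfig` is loaded:

```python
# Sketch, not the project's own example: config loading is elided and the import
# path graphrag.api.index is inferred from the file location.
import asyncio

from graphrag.api.index import build_index
from graphrag.config.enums import IndexingMethod


async def main(config):  # config: a GraphRagConfig built elsewhere
    # memory_profile is no longer accepted; passing it now raises a TypeError.
    results = await build_index(
        config=config,
        method=IndexingMethod.Standard,
        is_update_run=False,
        verbose=True,
    )
    for result in results:
        print(result)  # each entry is a PipelineRunResult


# asyncio.run(main(config)) once a config object is available
```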

packages/graphrag/graphrag/cli/index.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -43,7 +43,6 @@ def index_cli(
     root_dir: Path,
     method: IndexingMethod,
     verbose: bool,
-    memprofile: bool,
     cache: bool,
     dry_run: bool,
     skip_validation: bool,
@@ -55,7 +54,6 @@ def index_cli(
         method=method,
         is_update_run=False,
         verbose=verbose,
-        memprofile=memprofile,
         cache=cache,
         dry_run=dry_run,
         skip_validation=skip_validation,
@@ -66,7 +64,6 @@ def update_cli(
     root_dir: Path,
     method: IndexingMethod,
     verbose: bool,
-    memprofile: bool,
     cache: bool,
     skip_validation: bool,
 ):
@@ -80,7 +77,6 @@ def update_cli(
         method=method,
         is_update_run=True,
         verbose=verbose,
-        memprofile=memprofile,
         cache=cache,
         dry_run=False,
         skip_validation=skip_validation,
@@ -92,7 +88,6 @@ def _run_index(
     method,
     is_update_run,
     verbose,
-    memprofile,
     cache,
     dry_run,
     skip_validation,
@@ -129,7 +124,6 @@ def _run_index(
         config=config,
         method=method,
         is_update_run=is_update_run,
-        memory_profile=memprofile,
         callbacks=[ConsoleWorkflowCallbacks(verbose=verbose)],
         verbose=verbose,
     )
```

packages/graphrag/graphrag/cli/initialize.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -35,7 +35,9 @@
 logger = logging.getLogger(__name__)
 
 
-def initialize_project_at(path: Path, force: bool) -> None:
+def initialize_project_at(
+    path: Path, force: bool, model: str, embedding_model: str
+) -> None:
     """
     Initialize the project at the given path.
 
@@ -64,8 +66,11 @@ def initialize_project_at(path: Path, force: bool) -> None:
         root / (graphrag_config_defaults.input.storage.base_dir or "input")
     ).resolve()
     input_path.mkdir(parents=True, exist_ok=True)
-
-    settings_yaml.write_text(INIT_YAML, encoding="utf-8", errors="strict")
+    # using replace with custom tokens instead of format here because we have a placeholder for GRAPHRAG_API_KEY that is used later for .env overlay
+    formatted = INIT_YAML.replace("<DEFAULT_CHAT_MODEL>", model).replace(
+        "<DEFAULT_EMBEDDING_MODEL>", embedding_model
+    )
+    settings_yaml.write_text(formatted, encoding="utf-8", errors="strict")
 
     dotenv = root / ".env"
     if not dotenv.exists() or force:
```
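The inline comment explains the choice of `str.replace` over `str.format`: the generated settings must keep a literal `${GRAPHRAG_API_KEY}` placeholder for the later `.env` overlay, and format-style substitution would treat `{GRAPHRAG_API_KEY}` as a replacement field. A trimmed-down sketch of the idea, using an illustrative mini-template rather than the real `INIT_YAML`:

```python
# Illustrative mini-template; the real INIT_YAML lives in graphrag/config/init_content.py
# and the model values below are placeholders, not the project defaults.
TEMPLATE = """\
models:
  default_chat_model:
    api_key: ${GRAPHRAG_API_KEY}  # must survive untouched for the .env overlay
    model: <DEFAULT_CHAT_MODEL>
  default_embedding_model:
    api_key: ${GRAPHRAG_API_KEY}
    model: <DEFAULT_EMBEDDING_MODEL>
"""

# TEMPLATE.format(...) would try to interpolate {GRAPHRAG_API_KEY} and raise KeyError,
# so plain replace() on the custom tokens leaves the env-var placeholder alone.
formatted = TEMPLATE.replace("<DEFAULT_CHAT_MODEL>", "my-chat-model").replace(
    "<DEFAULT_EMBEDDING_MODEL>", "my-embedding-model"
)
print(formatted)
```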

packages/graphrag/graphrag/cli/main.py

Lines changed: 20 additions & 14 deletions
```diff
@@ -10,7 +10,11 @@
 
 import typer
 
-from graphrag.config.defaults import graphrag_config_defaults
+from graphrag.config.defaults import (
+    DEFAULT_CHAT_MODEL,
+    DEFAULT_EMBEDDING_MODEL,
+    graphrag_config_defaults,
+)
 from graphrag.config.enums import IndexingMethod, SearchMethod
 from graphrag.prompt_tune.defaults import LIMIT, MAX_TOKEN_COUNT, N_SUBSET_MAX, K
 from graphrag.prompt_tune.types import DocSelectionType
@@ -104,6 +108,18 @@ def _initialize_cli(
         resolve_path=True,
         autocompletion=ROOT_AUTOCOMPLETE,
     ),
+    model: str = typer.Option(
+        DEFAULT_CHAT_MODEL,
+        "--model",
+        "-m",
+        prompt="Specify the default chat model to use",
+    ),
+    embedding_model: str = typer.Option(
+        DEFAULT_EMBEDDING_MODEL,
+        "--embedding",
+        "-e",
+        prompt="Specify the default embedding model to use",
+    ),
     force: bool = typer.Option(
         False,
         "--force",
@@ -114,7 +130,9 @@ def _initialize_cli(
     """Generate a default configuration file."""
     from graphrag.cli.initialize import initialize_project_at
 
-    initialize_project_at(path=root, force=force)
+    initialize_project_at(
+        path=root, force=force, model=model, embedding_model=embedding_model
+    )
 
 
 @app.command("index")
@@ -143,11 +161,6 @@ def _index_cli(
         "-v",
         help="Run the indexing pipeline with verbose logging",
     ),
-    memprofile: bool = typer.Option(
-        False,
-        "--memprofile",
-        help="Run the indexing pipeline with memory profiling",
-    ),
     dry_run: bool = typer.Option(
         False,
         "--dry-run",
@@ -173,7 +186,6 @@ def _index_cli(
     index_cli(
         root_dir=root,
         verbose=verbose,
-        memprofile=memprofile,
         cache=cache,
         dry_run=dry_run,
         skip_validation=skip_validation,
@@ -207,11 +219,6 @@ def _update_cli(
         "-v",
         help="Run the indexing pipeline with verbose logging.",
     ),
-    memprofile: bool = typer.Option(
-        False,
-        "--memprofile",
-        help="Run the indexing pipeline with memory profiling.",
-    ),
     cache: bool = typer.Option(
         True,
         "--cache/--no-cache",
@@ -233,7 +240,6 @@ def _update_cli(
     update_cli(
         root_dir=root,
         verbose=verbose,
-        memprofile=memprofile,
         cache=cache,
         skip_validation=skip_validation,
         method=method,
```
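For reference, `typer.Option(prompt=...)` skips the prompt when the flag is supplied on the command line and otherwise asks interactively with the declared default pre-filled. A self-contained sketch of that pattern with placeholder defaults (not the real `DEFAULT_CHAT_MODEL` / `DEFAULT_EMBEDDING_MODEL` values):

```python
# Standalone sketch of the typer.Option(prompt=...) pattern used by the init command;
# the defaults here are placeholders, not graphrag's constants.
import typer

app = typer.Typer()


@app.command()
def init(
    model: str = typer.Option(
        "placeholder-chat-model",
        "--model",
        "-m",
        prompt="Specify the default chat model to use",
    ),
    embedding_model: str = typer.Option(
        "placeholder-embedding-model",
        "--embedding",
        "-e",
        prompt="Specify the default embedding model to use",
    ),
):
    """Prompt for models unless they are supplied as flags."""
    typer.echo(f"chat model: {model}, embedding model: {embedding_model}")


if __name__ == "__main__":
    app()
```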

packages/graphrag/graphrag/config/defaults.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -274,7 +274,7 @@ class LanguageModelDefaults:
     n: int = 1
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
-    request_timeout: float = 180.0
+    request_timeout: float = 600.0
     api_base: None = None
     api_version: None = None
     deployment_name: None = None
```

packages/graphrag/graphrag/config/errors.py

Lines changed: 0 additions & 9 deletions
```diff
@@ -33,15 +33,6 @@ def __init__(self, llm_type: str) -> None:
         super().__init__(msg)
 
 
-class LanguageModelConfigMissingError(ValueError):
-    """Missing model configuration error."""
-
-    def __init__(self, key: str = "") -> None:
-        """Init method definition."""
-        msg = f'A {key} model configuration is required. Please rerun `graphrag init` and set models["{key}"] in settings.yaml.'
-        super().__init__(msg)
-
-
 class ConflictingSettingsError(ValueError):
     """Missing model configuration error."""
 
```
packages/graphrag/graphrag/config/init_content.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -23,7 +23,7 @@
     model_provider: {defs.DEFAULT_MODEL_PROVIDER}
     auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
     api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
-    model: {defs.DEFAULT_CHAT_MODEL}
+    model: <DEFAULT_CHAT_MODEL>
     # api_base: https://<instance>.openai.azure.com
     # api_version: 2024-05-01-preview
     model_supports_json: true # recommended if this is available for your model.
@@ -37,7 +37,7 @@
     model_provider: {defs.DEFAULT_MODEL_PROVIDER}
     auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
     api_key: ${{GRAPHRAG_API_KEY}}
-    model: {defs.DEFAULT_EMBEDDING_MODEL}
+    model: <DEFAULT_EMBEDDING_MODEL>
     # api_base: https://<instance>.openai.azure.com
     # api_version: 2024-05-01-preview
     concurrent_requests: {language_model_defaults.concurrent_requests}
```

packages/graphrag/graphrag/config/models/graph_rag_config.py

Lines changed: 0 additions & 20 deletions
```diff
@@ -11,7 +11,6 @@
 import graphrag.config.defaults as defs
 from graphrag.config.defaults import graphrag_config_defaults
 from graphrag.config.enums import VectorStoreType
-from graphrag.config.errors import LanguageModelConfigMissingError
 from graphrag.config.models.basic_search_config import BasicSearchConfig
 from graphrag.config.models.cache_config import CacheConfig
 from graphrag.config.models.chunking_config import ChunkingConfig
@@ -58,24 +57,6 @@ def __str__(self):
         default=graphrag_config_defaults.models,
     )
 
-    def _validate_models(self) -> None:
-        """Validate the models configuration.
-
-        Ensure both a default chat model and default embedding model
-        have been defined. Other models may also be defined but
-        defaults are required for the time being as places of the
-        code fallback to default model configs instead
-        of specifying a specific model.
-
-        TODO: Don't fallback to default models elsewhere in the code.
-        Forcing code to specify a model to use and allowing for any
-        names for model configurations.
-        """
-        if defs.DEFAULT_CHAT_MODEL_ID not in self.models:
-            raise LanguageModelConfigMissingError(defs.DEFAULT_CHAT_MODEL_ID)
-        if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
-            raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)
-
     def _validate_retry_services(self) -> None:
         """Validate the retry services configuration."""
         retry_factory = RetryFactory()
@@ -329,7 +310,6 @@ def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
     @model_validator(mode="after")
     def _validate_model(self):
         """Validate the model configuration."""
-        self._validate_models()
         self._validate_input_pattern()
         self._validate_input_base_dir()
         self._validate_reporting_base_dir()
```

tests/fixtures/min-csv/config.json

Lines changed: 1 addition & 1 deletion
```diff
@@ -51,7 +51,7 @@
             "period",
             "size"
         ],
-        "max_runtime": 1200,
+        "max_runtime": 2000,
        "expected_artifacts": ["community_reports.parquet"]
    },
    "create_final_text_units": {
```
