Skip to content

Commit d054feb

Browse files
committed
evaluation model added to config
1 parent 21c756e commit d054feb

File tree

2 files changed

+79
-23
lines changed

2 files changed

+79
-23
lines changed

lib/idp_common_pkg/idp_common/config/models.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,46 @@ def parse_int(cls, v: Any) -> int:
375375
return int(v)
376376

377377

378+
class EvaluationLLMMethodConfig(BaseModel):
    """Evaluation LLM method configuration.

    Holds the model ID, the prompts and the sampling parameters used for
    LLM-based evaluation. Numeric fields also accept string input; the
    "before" validators below coerce such values prior to the range checks
    declared on each field.
    """

    top_p: float = Field(default=0.1, ge=0.0, le=1.0)
    max_tokens: int = Field(default=4096, gt=0)
    top_k: float = Field(default=5.0, ge=0.0)
    task_prompt: str = Field(default="", description="Task prompt for evaluation")
    temperature: float = Field(default=0.0, ge=0.0, le=1.0)
    model: str = Field(
        default="us.anthropic.claude-3-haiku-20240307-v1:0",
        description="Bedrock model ID for evaluation",
    )
    system_prompt: str = Field(default="", description="System prompt for evaluation")

    @field_validator("temperature", "top_p", "top_k", mode="before")
    @classmethod
    def parse_float(cls, v: Any) -> float:
        """Parse float from string or number.

        An empty string is mapped to 0.0, which satisfies the lower bound
        (``ge=0.0``) of every field this validator is attached to.
        """
        if isinstance(v, str):
            return float(v) if v else 0.0
        return float(v)

    @field_validator("max_tokens", mode="before")
    @classmethod
    def parse_int(cls, v: Any) -> int:
        """Parse int from string or number.

        An empty string falls back to the declared default of ``max_tokens``.
        Mapping it to 0 would always violate the field's ``gt=0`` constraint
        and turn a blank value into a misleading "greater than 0" error.
        """
        if isinstance(v, str):
            if not v:
                # Blank input means "not set": use the field's own default so
                # the value and the Field declaration cannot drift apart.
                return cls.model_fields["max_tokens"].default
            return int(v)
        return int(v)
407+
408+
409+
class EvaluationConfig(BaseModel):
    """Evaluation configuration for assessment.

    Groups the evaluation settings; at present this is only the LLM method
    configuration, which is built with its own defaults when not supplied.
    """

    llm_method: EvaluationLLMMethodConfig = Field(
        description="LLM method configuration for evaluation",
        default_factory=EvaluationLLMMethodConfig,
    )
416+
417+
378418
class DiscoveryModelConfig(BaseModel):
379419
"""Discovery model configuration for class extraction"""
380420

@@ -498,6 +538,9 @@ class IDPConfig(BaseModel):
498538
discovery: DiscoveryConfig = Field(
499539
default_factory=DiscoveryConfig, description="Discovery configuration"
500540
)
541+
evaluation: EvaluationConfig = Field(
542+
default_factory=EvaluationConfig, description="Evaluation configuration"
543+
)
501544

502545
model_config = ConfigDict(
503546
# Do not allow extra fields - all config should be explicit

lib/idp_common_pkg/idp_common/evaluation/service.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
import time
1515
import traceback
1616
from concurrent.futures import ThreadPoolExecutor
17-
from typing import Any, Dict, Generator, List, Optional, Tuple
17+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
1818

1919
from idp_common import s3
20+
from idp_common.config.models import IDPConfig
2021
from idp_common.config.schema_constants import (
2122
SCHEMA_DESCRIPTION,
2223
SCHEMA_ITEMS,
@@ -45,37 +46,45 @@ class EvaluationService:
4546
"""Service for evaluating document extraction results."""
4647

4748
def __init__(
48-
self, region: str = None, config: Dict[str, Any] = None, max_workers: int = 10
49+
self,
50+
region: str = None,
51+
config: Union[Dict[str, Any], IDPConfig] = None,
52+
max_workers: int = 10,
4953
):
5054
"""
5155
Initialize the evaluation service.
5256
5357
Args:
5458
region: AWS region
55-
config: Configuration dictionary containing evaluation settings
59+
config: Configuration dictionary or IDPConfig model containing evaluation settings
5660
max_workers: Maximum number of concurrent workers for section evaluation
5761
"""
58-
self.config = config or {}
59-
self.region = (
60-
region or self.config.get("region") or os.environ.get("AWS_REGION")
61-
)
62+
# Convert dict to IDPConfig if needed
63+
if config is not None and isinstance(config, dict):
64+
config_model: IDPConfig = IDPConfig(**config)
65+
elif config is None:
66+
config_model = IDPConfig()
67+
else:
68+
config_model = config
69+
70+
self.config = config_model
71+
self.region = region or os.environ.get("AWS_REGION")
6272
self.max_workers = max_workers
6373

64-
# Set default LLM evaluation settings
65-
self.llm_config = self.config.get("evaluation", {}).get("llm_method", {})
66-
self.default_model = self.llm_config.get(
67-
"model", "anthropic.claude-3-sonnet-20240229-v1:0"
68-
)
69-
self.default_temperature = self.llm_config.get("temperature", 0.0)
70-
self.default_top_k = self.llm_config.get("top_k", 5)
71-
self.default_system_prompt = self.llm_config.get(
72-
"system_prompt",
73-
"""You are an evaluator that helps determine if the predicted and expected values match for document attribute extraction. You will consider the context and meaning rather than just exact string matching.""",
74+
# Set default LLM evaluation settings from typed config
75+
self.default_model = self.config.evaluation.llm_method.model
76+
self.default_temperature = self.config.evaluation.llm_method.temperature
77+
self.default_top_k = self.config.evaluation.llm_method.top_k
78+
self.default_top_p = self.config.evaluation.llm_method.top_p
79+
self.default_max_tokens = self.config.evaluation.llm_method.max_tokens
80+
self.default_system_prompt = (
81+
self.config.evaluation.llm_method.system_prompt
82+
or """You are an evaluator that helps determine if the predicted and expected values match for document attribute extraction. You will consider the context and meaning rather than just exact string matching."""
7483
)
7584

76-
self.default_task_prompt = self.llm_config.get(
77-
"task_prompt",
78-
"""I need to evaluate attribute extraction for a document of class: {DOCUMENT_CLASS}.
85+
self.default_task_prompt = (
86+
self.config.evaluation.llm_method.task_prompt
87+
or """I need to evaluate attribute extraction for a document of class: {DOCUMENT_CLASS}.
7988
8089
For the attribute named "{ATTRIBUTE_NAME}" described as "{ATTRIBUTE_DESCRIPTION}":
8190
- Expected value: {EXPECTED_VALUE}
@@ -93,7 +102,7 @@ def __init__(
93102
"score": 0.0 to 1.0,
94103
"reason": "Your explanation here"
95104
}
96-
""",
105+
"""
97106
)
98107

99108
logger.info(
@@ -111,9 +120,13 @@ def _get_attributes_for_class(self, class_name: str) -> List[EvaluationAttribute
111120
Returns:
112121
List of attribute configurations (flattened for nested structures)
113122
"""
114-
classes = self.config.get("classes", [])
123+
classes = self.config.classes
115124
for schema in classes:
116-
if schema.get(X_AWS_IDP_DOCUMENT_TYPE, "").lower() == class_name.lower():
125+
if (
126+
isinstance(schema, dict)
127+
and schema.get(X_AWS_IDP_DOCUMENT_TYPE, "").lower()
128+
== class_name.lower()
129+
):
117130
properties = schema.get(SCHEMA_PROPERTIES, {})
118131
return list(self._walk_properties(properties))
119132

0 commit comments

Comments
 (0)