Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,19 @@ def combine_with_priority(self):


class IndicatorCandidatesLLMFormatter:
def __init__(self, dataset_id_2_name: dict[str, str]):
self.dataset_id_2_name = dataset_id_2_name
def __init__(
self,
dataset_alias_to_name: dict[str, str],
dataset_id_to_source_id: dict[str, str] | None = None,
):
self.dataset_alias_to_name = dataset_alias_to_name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's avoid using "alias" term - to avoid confusion with dimension_alias dataset config field

self._dataset_id_to_source_id = dataset_id_to_source_id or {}

@classmethod
def get_candidate_details_by_dataset(
cls, candidates: list[ScoredIndicatorCandidate]
cls,
candidates: list[ScoredIndicatorCandidate],
dataset_id_to_source_id: dict[str, str] | None = None,
) -> DatasetDimensionTermNameType:
cand_by_dataset: dict[str, list[ScoredIndicatorCandidate]] = {}
for c in candidates:
Expand Down Expand Up @@ -102,13 +109,14 @@ def get_candidate_details_by_dataset(
for details in first_ind_comp_details
}

dataset_data[dataset_id] = cur_dataset_dimensions
key = (dataset_id_to_source_id or {}).get(dataset_id, dataset_id)
dataset_data[key] = cur_dataset_dimensions
return dataset_data

def _data2text(self, candidate_details_by_dataset: DatasetDimensionTermNameType) -> str:
lines = []
for dataset_id, dimension_data in candidate_details_by_dataset.items():
dataset_name = self.dataset_id_2_name[dataset_id]
dataset_name = self.dataset_alias_to_name[dataset_id]
lines.append(
f'Dataset id: "{dataset_id}", dataset name: "{dataset_name}". '
f'Dimensions (keys are dimension IDs):'
Expand All @@ -119,7 +127,9 @@ def _data2text(self, candidate_details_by_dataset: DatasetDimensionTermNameType)
return res

def run(self, candidates: list[ScoredIndicatorCandidate]):
data = self.get_candidate_details_by_dataset(candidates)
data = self.get_candidate_details_by_dataset(
candidates, dataset_id_to_source_id=self._dataset_id_to_source_id
)
res = self._data2text(data)
return res

Expand All @@ -129,6 +139,16 @@ class LLMResponseBase(BaseModel, ABC):
def get_queries(self) -> DatasetDimQueriesType:
pass

@abstractmethod
def translate_dataset_ids(self, source_id_to_dataset_id: dict[str, str]) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use translate_forward and translate_backward for clarity

"""Replace source_id keys in LLM-produced queries with dataset UUIDs.

The LLM receives short source_ids (e.g. "IMF:IFS(1.0)") instead of
dataset UUIDs to avoid garbling. This method translates the keys back
to UUIDs so downstream pipeline steps (populate_stage, _remove_hallucinations)
can match them against datasets_dict.
"""

async def populate_stage(self, inputs: dict) -> None:
logger.info(f'default {type(self).__name__}.populate_stage(): doing nothing')

Expand Down Expand Up @@ -178,7 +198,7 @@ class LLMResponse(LLMResponseBase):
dataset_queries: DatasetDimQueriesType = Field(
default={},
description=(
'mapping from dataset id (not name!!!) to query. '
'mapping from dataset id (as shown in candidates, not name!) to query. '
'query is a mapping from dimension id '
'to list of dimension value ids required in the user query. '
),
Expand All @@ -187,6 +207,12 @@ class LLMResponse(LLMResponseBase):
def get_queries(self) -> DatasetDimQueriesType:
return self.dataset_queries

def translate_dataset_ids(self, source_id_to_dataset_id: dict[str, str]) -> None:
    """Rewrite the flat V1 query dict so its keys are dataset UUIDs.

    V1 stores queries in a single dict keyed by dataset id; keys that have
    no entry in *source_id_to_dataset_id* are kept unchanged.
    """
    translated = {}
    for dataset_key, query in self.dataset_queries.items():
        translated[source_id_to_dataset_id.get(dataset_key, dataset_key)] = query
    self.dataset_queries = translated

async def populate_stage(self, inputs: dict) -> None:
state = ChainParameters.get_state(inputs)
if not state.get(StateVarsConfig.SHOW_DEBUG_STAGES):
Expand Down Expand Up @@ -262,6 +288,18 @@ async def _populate_candidates_stage(self, inputs: dict):
content = f'```yaml\n{candidates_formatted}\n```'
stage.append_content(content)

def _translate_aliases(self, inputs: dict):
"""Chain step: convert source_id keys in the LLM response back to dataset UUIDs.

MUST be called right after the LLM call and before populate_stage /
_remove_hallucinations, which expect dataset UUIDs as keys.
"""
source_id_to_dataset_id = inputs.get('source_id_to_dataset_id', {})
if source_id_to_dataset_id:
parsed_response: LLMResponseBase = inputs[self.PARSED_RESPONSE_KEY]
parsed_response.translate_dataset_ids(source_id_to_dataset_id)
return inputs

def _create_chain_inner(self, llm):
async def async_lambda(inputs):
return await inputs[self.PARSED_RESPONSE_KEY].populate_stage(inputs)
Expand All @@ -270,6 +308,7 @@ async def async_lambda(inputs):
self._format_candidates
| RunnablePassthrough.assign(_=self._populate_candidates_stage)
| RunnablePassthrough.assign(**{self.PARSED_RESPONSE_KEY: self._prompt_template | llm})
| self._translate_aliases # call it right after llm
| RunnablePassthrough.assign(_=async_lambda)
| self._remove_hallucinations
)
Expand Down Expand Up @@ -303,10 +342,34 @@ def _format_candidates(self, inputs: dict):
candidates = self._get_candidates(inputs)
chain_state = ChainState(**inputs)
datasets_dict = chain_state.datasets_dict
dataset_id_2_name = {ds.data.entity_id: ds.data.name for ds in datasets_dict.values()}
formatter = IndicatorCandidatesLLMFormatter(dataset_id_2_name=dataset_id_2_name)

# Build bidirectional mapping: dataset UUID <-> source_id (short URN)
# Source IDs are used as LLM-facing identifiers to avoid fragile UUIDs
dataset_id_to_source_id: dict[str, str] = {}
source_id_to_dataset_id: dict[str, str] = {}

for entity_id, ds in datasets_dict.items():
source_id = ds.data.source_id
assert source_id is not None, f'Dataset {entity_id} has no source_id'
if source_id in source_id_to_dataset_id:
logger.warning(
f'Duplicate source_id "{source_id}" for datasets '
f'{source_id_to_dataset_id[source_id]} and {entity_id}'
)
dataset_id_to_source_id[entity_id] = source_id
source_id_to_dataset_id[source_id] = entity_id

dataset_source_id_to_name = {
ds.data.source_id: ds.data.name for ds in datasets_dict.values()
}

formatter = IndicatorCandidatesLLMFormatter(
dataset_alias_to_name=dataset_source_id_to_name,
dataset_id_to_source_id=dataset_id_to_source_id,
)
text = formatter.run(candidates)
inputs['yaml_candidates'] = text
inputs['source_id_to_dataset_id'] = source_id_to_dataset_id
return inputs

def _remove_hallucinations(self, inputs: dict) -> DatasetDimQueries:
Expand Down Expand Up @@ -486,6 +549,18 @@ def get_queries(self) -> DatasetDimQueriesType:
res = self.queries.combine_with_priority().queries
return res

def translate_dataset_ids(self, source_id_to_dataset_id: dict[str, str]) -> None:
    """Translate dataset keys back to UUIDs in both V2 relevancy buckets.

    V2 keeps separate exact/child query dicts, so each is remapped on its
    own; keys absent from the mapping are left untouched.
    """

    def _remap(queries):
        # One pass per bucket; unknown keys fall through unchanged.
        remapped = {}
        for key, query in queries.items():
            remapped[source_id_to_dataset_id.get(key, key)] = query
        return remapped

    exact_queries = _remap(self.queries.exact.queries)
    child_queries = _remap(self.queries.child.queries)
    self.queries = DatasetDimQueriesByRelevancy(
        exact=DatasetDimQueries(queries=exact_queries),
        child=DatasetDimQueries(queries=child_queries),
    )

async def populate_stage(self, inputs: dict) -> None:
state = ChainParameters.get_state(inputs)
data_service = ChainParameters.get_data_service(inputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ class LLMResponse(LLMResponseBase):
def get_queries(self):
pass # not used

def translate_dataset_ids(self, source_id_to_dataset_id: dict[str, str]) -> None:
    """Intentionally a no-op: V3 keys queries by candidate index, not dataset id."""

class CombinedOutput(BaseModel):
queries: DatasetDimQueries
selection_status_str: str
Expand Down
Loading