-
Notifications
You must be signed in to change notification settings - Fork 1
feat: Add translation logic from UUID to Source ID and back in Indicator Selection Chain #172 #173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
veryniceuser
wants to merge
5
commits into
development
Choose a base branch
from
feature/translate_dataset_uuids_for_indicator_selection_prompt
base: development
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
fb69d3e
Add translation logic from UUID to Source ID and back in Indicator Se…
DzmitryVabishchewichTR b18afa7
Revert field renaming: dataset_id_to_name
veryniceuser a999972
refactor: rename dataset ID translation methods for clarity
veryniceuser 40466b8
feat: separate candidates retrieval and dataset ids mapping
veryniceuser ca18f2e
Merge remote-tracking branch 'origin/development' into feature/transl…
veryniceuser File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,12 +67,19 @@ def combine_with_priority(self): | |
|
|
||
|
|
||
| class IndicatorCandidatesLLMFormatter: | ||
| def __init__(self, dataset_id_2_name: dict[str, str]): | ||
| self.dataset_id_2_name = dataset_id_2_name | ||
| def __init__( | ||
| self, | ||
| dataset_alias_to_name: dict[str, str], | ||
| dataset_id_to_source_id: dict[str, str] | None = None, | ||
| ): | ||
| self.dataset_alias_to_name = dataset_alias_to_name | ||
| self._dataset_id_to_source_id = dataset_id_to_source_id or {} | ||
|
|
||
| @classmethod | ||
| def get_candidate_details_by_dataset( | ||
| cls, candidates: list[ScoredIndicatorCandidate] | ||
| cls, | ||
| candidates: list[ScoredIndicatorCandidate], | ||
| dataset_id_to_source_id: dict[str, str] | None = None, | ||
| ) -> DatasetDimensionTermNameType: | ||
| cand_by_dataset: dict[str, list[ScoredIndicatorCandidate]] = {} | ||
| for c in candidates: | ||
|
|
@@ -102,13 +109,14 @@ def get_candidate_details_by_dataset( | |
| for details in first_ind_comp_details | ||
| } | ||
|
|
||
| dataset_data[dataset_id] = cur_dataset_dimensions | ||
| key = (dataset_id_to_source_id or {}).get(dataset_id, dataset_id) | ||
| dataset_data[key] = cur_dataset_dimensions | ||
| return dataset_data | ||
|
|
||
| def _data2text(self, candidate_details_by_dataset: DatasetDimensionTermNameType) -> str: | ||
| lines = [] | ||
| for dataset_id, dimension_data in candidate_details_by_dataset.items(): | ||
| dataset_name = self.dataset_id_2_name[dataset_id] | ||
| dataset_name = self.dataset_alias_to_name[dataset_id] | ||
| lines.append( | ||
| f'Dataset id: "{dataset_id}", dataset name: "{dataset_name}". ' | ||
| f'Dimensions (keys are dimension IDs):' | ||
|
|
@@ -119,7 +127,9 @@ def _data2text(self, candidate_details_by_dataset: DatasetDimensionTermNameType) | |
| return res | ||
|
|
||
| def run(self, candidates: list[ScoredIndicatorCandidate]): | ||
| data = self.get_candidate_details_by_dataset(candidates) | ||
| data = self.get_candidate_details_by_dataset( | ||
| candidates, dataset_id_to_source_id=self._dataset_id_to_source_id | ||
| ) | ||
| res = self._data2text(data) | ||
| return res | ||
|
|
||
|
|
@@ -129,6 +139,16 @@ class LLMResponseBase(BaseModel, ABC): | |
| def get_queries(self) -> DatasetDimQueriesType: | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| def translate_dataset_ids(self, source_id_to_dataset_id: dict[str, str]) -> None: | ||
|
||
| """Replace source_id keys in LLM-produced queries with dataset UUIDs. | ||
|
|
||
| The LLM receives short source_ids (e.g. "IMF:IFS(1.0)") instead of | ||
| dataset UUIDs to avoid garbling. This method translates the keys back | ||
| to UUIDs so downstream pipeline steps (populate_stage, _remove_hallucinations) | ||
| can match them against datasets_dict. | ||
| """ | ||
|
|
||
| async def populate_stage(self, inputs: dict) -> None: | ||
| logger.info(f'default {type(self).__name__}.populate_stage(): doing nothing') | ||
|
|
||
|
|
@@ -178,7 +198,7 @@ class LLMResponse(LLMResponseBase): | |
| dataset_queries: DatasetDimQueriesType = Field( | ||
| default={}, | ||
| description=( | ||
| 'mapping from dataset id (not name!!!) to query. ' | ||
| 'mapping from dataset id (as shown in candidates, not name!) to query. ' | ||
| 'query is a mapping from dimension id ' | ||
| 'to list of dimension value ids required in the user query. ' | ||
| ), | ||
|
|
@@ -187,6 +207,12 @@ class LLMResponse(LLMResponseBase): | |
| def get_queries(self) -> DatasetDimQueriesType: | ||
| return self.dataset_queries | ||
|
|
||
| def translate_dataset_ids(self, source_id_to_dataset_id: dict[str, str]) -> None: | ||
| """V1 stores queries in a flat dict keyed by dataset id.""" | ||
| self.dataset_queries = { | ||
| source_id_to_dataset_id.get(k, k): v for k, v in self.dataset_queries.items() | ||
| } | ||
|
|
||
| async def populate_stage(self, inputs: dict) -> None: | ||
| state = ChainParameters.get_state(inputs) | ||
| if not state.get(StateVarsConfig.SHOW_DEBUG_STAGES): | ||
|
|
@@ -262,6 +288,18 @@ async def _populate_candidates_stage(self, inputs: dict): | |
| content = f'```yaml\n{candidates_formatted}\n```' | ||
| stage.append_content(content) | ||
|
|
||
| def _translate_aliases(self, inputs: dict): | ||
| """Chain step: convert source_id keys in the LLM response back to dataset UUIDs. | ||
|
|
||
| MUST be called right after the LLM call and before populate_stage / | ||
| _remove_hallucinations, which expect dataset UUIDs as keys. | ||
| """ | ||
| source_id_to_dataset_id = inputs.get('source_id_to_dataset_id', {}) | ||
| if source_id_to_dataset_id: | ||
| parsed_response: LLMResponseBase = inputs[self.PARSED_RESPONSE_KEY] | ||
| parsed_response.translate_dataset_ids(source_id_to_dataset_id) | ||
| return inputs | ||
|
|
||
| def _create_chain_inner(self, llm): | ||
| async def async_lambda(inputs): | ||
| return await inputs[self.PARSED_RESPONSE_KEY].populate_stage(inputs) | ||
|
|
@@ -270,6 +308,7 @@ async def async_lambda(inputs): | |
| self._format_candidates | ||
| | RunnablePassthrough.assign(_=self._populate_candidates_stage) | ||
| | RunnablePassthrough.assign(**{self.PARSED_RESPONSE_KEY: self._prompt_template | llm}) | ||
| | self._translate_aliases # call it right after llm | ||
| | RunnablePassthrough.assign(_=async_lambda) | ||
| | self._remove_hallucinations | ||
| ) | ||
|
|
@@ -303,10 +342,34 @@ def _format_candidates(self, inputs: dict): | |
| candidates = self._get_candidates(inputs) | ||
| chain_state = ChainState(**inputs) | ||
| datasets_dict = chain_state.datasets_dict | ||
| dataset_id_2_name = {ds.data.entity_id: ds.data.name for ds in datasets_dict.values()} | ||
| formatter = IndicatorCandidatesLLMFormatter(dataset_id_2_name=dataset_id_2_name) | ||
|
|
||
| # Build bidirectional mapping: dataset UUID <-> source_id (short URN) | ||
| # Source IDs are used as LLM-facing identifiers to avoid fragile UUIDs | ||
| dataset_id_to_source_id: dict[str, str] = {} | ||
| source_id_to_dataset_id: dict[str, str] = {} | ||
|
|
||
| for entity_id, ds in datasets_dict.items(): | ||
| source_id = ds.data.source_id | ||
| assert source_id is not None, f'Dataset {entity_id} has no source_id' | ||
| if source_id in source_id_to_dataset_id: | ||
| logger.warning( | ||
| f'Duplicate source_id "{source_id}" for datasets ' | ||
| f'{source_id_to_dataset_id[source_id]} and {entity_id}' | ||
| ) | ||
| dataset_id_to_source_id[entity_id] = source_id | ||
| source_id_to_dataset_id[source_id] = entity_id | ||
|
|
||
| dataset_source_id_to_name = { | ||
| ds.data.source_id: ds.data.name for ds in datasets_dict.values() | ||
| } | ||
|
|
||
| formatter = IndicatorCandidatesLLMFormatter( | ||
| dataset_alias_to_name=dataset_source_id_to_name, | ||
| dataset_id_to_source_id=dataset_id_to_source_id, | ||
| ) | ||
| text = formatter.run(candidates) | ||
| inputs['yaml_candidates'] = text | ||
| inputs['source_id_to_dataset_id'] = source_id_to_dataset_id | ||
| return inputs | ||
|
|
||
| def _remove_hallucinations(self, inputs: dict) -> DatasetDimQueries: | ||
|
|
@@ -486,6 +549,18 @@ def get_queries(self) -> DatasetDimQueriesType: | |
| res = self.queries.combine_with_priority().queries | ||
| return res | ||
|
|
||
| def translate_dataset_ids(self, source_id_to_dataset_id: dict[str, str]) -> None: | ||
| """V2 splits queries into exact/child relevancy sub-dicts, | ||
| so both must be translated independently.""" | ||
|
|
||
| def _translate(q: DatasetDimQueriesType) -> DatasetDimQueriesType: | ||
| return {source_id_to_dataset_id.get(k, k): v for k, v in q.items()} | ||
|
|
||
| self.queries = DatasetDimQueriesByRelevancy( | ||
| exact=DatasetDimQueries(queries=_translate(self.queries.exact.queries)), | ||
| child=DatasetDimQueries(queries=_translate(self.queries.child.queries)), | ||
| ) | ||
|
|
||
| async def populate_stage(self, inputs: dict) -> None: | ||
| state = ChainParameters.get_state(inputs) | ||
| data_service = ChainParameters.get_data_service(inputs) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's avoid using the "alias" term, to avoid confusion with the
`dimension_alias` dataset config field.