Skip to content

Commit de6da38

Browse files
authored
fix: prompt naming related issues (#1743)
1 parent a2a2cef commit de6da38

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

src/ragas/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def parse_run_traces(
164164
prompt_trace = traces[prompt_uuid]
165165
prompt_traces[f"{prompt_trace.name}"] = {
166166
"input": prompt_trace.inputs.get("data", {}),
167-
"output": prompt_trace.outputs.get("output", {}),
167+
"output": prompt_trace.outputs.get("output", {})[0],
168168
}
169169
metric_traces[f"{metric_trace.name}"] = prompt_traces
170170
parased_traces.append(metric_traces)

src/ragas/optimizers/genetic.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def dict_to_str(dict: t.Dict[str, t.Any]) -> str:
514514
exclude_none=True
515515
)
516516
),
517-
output=traces[idx][prompt_name]["output"][0].model_dump(
517+
output=traces[idx][prompt_name]["output"].model_dump(
518518
exclude_none=True
519519
),
520520
expected_output=dataset[idx]["prompts"][prompt_name][
@@ -586,14 +586,6 @@ def evaluate_candidate(
586586
_run_id=run_id,
587587
_pbar=parent_pbar,
588588
)
589-
# remap the traces to the original prompt names
590-
remap_traces = {val.name: key for key, val in self.metric.get_prompts().items()}
591-
for trace in results.traces:
592-
for key in remap_traces:
593-
if key in trace[self.metric.name]:
594-
trace[self.metric.name][remap_traces[key]] = trace[
595-
self.metric.name
596-
].pop(key)
597589
return results
598590

599591
def evaluate_fitness(

src/ragas/prompt/mixin.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,21 @@ class PromptMixin:
2121
eg: [BaseSynthesizer][ragas.testset.synthesizers.base.BaseSynthesizer], [MetricWithLLM][ragas.metrics.base.MetricWithLLM]
2222
"""
2323

24+
def _get_prompts(self) -> t.Dict[str, PydanticPrompt]:
25+
26+
prompts = {}
27+
for key, value in inspect.getmembers(self):
28+
if isinstance(value, PydanticPrompt):
29+
prompts.update({key: value})
30+
return prompts
31+
2432
def get_prompts(self) -> t.Dict[str, PydanticPrompt]:
2533
"""
2634
Returns a dictionary of prompts for the class.
2735
"""
2836
prompts = {}
29-
for name, value in inspect.getmembers(self):
30-
if isinstance(value, PydanticPrompt):
31-
prompts.update({name: value})
37+
for _, value in self._get_prompts().items():
38+
prompts.update({value.name: value})
3239
return prompts
3340

3441
def set_prompts(self, **prompts):
@@ -41,6 +48,7 @@ def set_prompts(self, **prompts):
4148
If the prompt is not an instance of `PydanticPrompt`.
4249
"""
4350
available_prompts = self.get_prompts()
51+
name_to_var = {v.name: k for k, v in self._get_prompts().items()}
4452
for key, value in prompts.items():
4553
if key not in available_prompts:
4654
raise ValueError(
@@ -50,7 +58,7 @@ def set_prompts(self, **prompts):
5058
raise ValueError(
5159
f"Prompt with name '{key}' must be an instance of 'ragas.prompt.PydanticPrompt'"
5260
)
53-
setattr(self, key, value)
61+
setattr(self, name_to_var[key], value)
5462

5563
async def adapt_prompts(
5664
self, language: str, llm: BaseRagasLLM, adapt_instruction: bool = False

0 commit comments

Comments
 (0)