Skip to content

Commit 977f5ea

Browse files
committed
synthesizer ui updated
1 parent 686ee75 commit 977f5ea

File tree

5 files changed

+125
-137
lines changed

5 files changed

+125
-137
lines changed

backend/llm_eval/qa_catalog/generator/implementation/ragas/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class RagasQACatalogGeneratorPersona(ApiModel):
2626
class RagasQACatalogGeneratorConfig(QACatalogGeneratorConfig[RagasGeneratorType]):
2727
knowledge_graph_location: Path | None
2828
sample_count: int
29-
query_distribution: dict[RagasQACatalogQuerySynthesizer, float]
29+
query_distribution: list[RagasQACatalogQuerySynthesizer]
3030
personas: list[RagasQACatalogGeneratorPersona] | None
3131

3232
use_existing_knowledge_graph: bool = True

backend/llm_eval/qa_catalog/generator/implementation/ragas/generator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,9 @@ def __init__(
138138
)
139139

140140
def validate_config(self) -> None:
141-
if sum(self.config.query_distribution.values()) != 1:
141+
if not self.config.query_distribution:
142142
raise ValueError(
143-
"Given query distribution for the generation is invalid, "
144-
"distribution weights should sum up to 1"
143+
"At least one query synthesizer must be selected for QA generation"
145144
)
146145

147146
def _load_and_process_documents(self) -> list[Document]:
@@ -325,7 +324,7 @@ def create_query_distribution(
325324
properties = ["headlines", "keyphrases", "entities"]
326325

327326
selected_synthesizer_classes = (
328-
query_synthesizer_classes[q] for q in self.config.query_distribution.keys()
327+
query_synthesizer_classes[q] for q in self.config.query_distribution
329328
)
330329

331330
synthesizers = []

frontend/app/[locale]/(authenticated)/qa-catalogs/generate/page.test.tsx

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import "@/app/test-utils/mock-intl";
12
import "@/app/test-utils/mock-router";
23
import "@/app/test-utils/mock-toast";
3-
import "@/app/test-utils/mock-intl";
44

55
import { addToast } from "@heroui/react";
66
import {
@@ -24,6 +24,7 @@ import {
2424
QaCatalogGenerationConfig,
2525
QaCatalogGenerationModelConfigurationSchema,
2626
qaCatalogGetGeneratorTypes,
27+
RagasQaCatalogQuerySynthesizer,
2728
} from "@/app/client";
2829
import {
2930
clearComboBox,
@@ -105,7 +106,7 @@ describe("Synthetic QA Catalog Generation Page", () => {
105106

106107
await clearComboBox(user, ragasLabel("llmEndpointId"));
107108

108-
await inputDistributionsForRagas(testCase.configuration, true);
109+
await inputDistributionsForRagas(testCase.configuration);
109110
}
110111
};
111112

@@ -143,13 +144,13 @@ describe("Synthetic QA Catalog Generation Page", () => {
143144
configuration: {
144145
type: "RAGAS",
145146
personas: [],
146-
queryDistribution: {
147-
MULTI_HOP_ABSTRACT: 0,
148-
MULTI_HOP_SPECIFIC: 0.5,
149-
SINGLE_HOP_SPECIFIC: 0.5,
150-
},
147+
queryDistribution: [
148+
RagasQaCatalogQuerySynthesizer.MULTI_HOP_SPECIFIC,
149+
RagasQaCatalogQuerySynthesizer.SINGLE_HOP_SPECIFIC,
150+
],
151151
sampleCount: 5,
152152
knowledgeGraphLocation: null,
153+
useExistingKnowledgeGraph: true,
153154
},
154155
modelConfigSchema: {
155156
type: "RAGAS",
@@ -168,20 +169,21 @@ describe("Synthetic QA Catalog Generation Page", () => {
168169
{ name: "p-1", description: "description-1" },
169170
{ name: "p-2", description: "description-2" },
170171
],
171-
queryDistribution: {
172-
MULTI_HOP_ABSTRACT: 0.2,
173-
MULTI_HOP_SPECIFIC: 0.4,
174-
SINGLE_HOP_SPECIFIC: 0.4,
175-
},
172+
queryDistribution: [
173+
RagasQaCatalogQuerySynthesizer.MULTI_HOP_ABSTRACT,
174+
RagasQaCatalogQuerySynthesizer.MULTI_HOP_SPECIFIC,
175+
RagasQaCatalogQuerySynthesizer.SINGLE_HOP_SPECIFIC,
176+
],
176177
sampleCount: 5,
177178
knowledgeGraphLocation: null,
179+
useExistingKnowledgeGraph: true,
178180
},
179181
modelConfigSchema: {
180182
type: "RAGAS",
181183
llmEndpoint: "llm-1",
182184
},
183185
dataSourceConfigId: "data-source-config-2",
184-
files: createFiles(["file-1", "file-2"]),
186+
files: createFiles(["file-3", "file-4"]),
185187
},
186188
];
187189

@@ -396,20 +398,50 @@ describe("Synthetic QA Catalog Generation Page", () => {
396398

397399
const inputDistributionsForRagas = async (
398400
configuration: QaCatalogGenerationConfig,
399-
clear: boolean = false,
400401
) => {
401402
if (configuration.type == "RAGAS") {
402-
const distributions = configuration.queryDistribution;
403-
404-
for (const [synth, weight] of Object.entries(distributions)) {
405-
const slider = screen.getByTestId(`queryDistributionSlider-${synth}`);
406-
const sliderInput = within(slider).getByRole("slider", {
407-
hidden: true,
408-
});
403+
const selectedSynthesizers = configuration.queryDistribution;
404+
405+
// First, uncheck all checkboxes to start from clean state
406+
const allSynthesizers = [
407+
RagasQaCatalogQuerySynthesizer.MULTI_HOP_ABSTRACT,
408+
RagasQaCatalogQuerySynthesizer.MULTI_HOP_SPECIFIC,
409+
RagasQaCatalogQuerySynthesizer.SINGLE_HOP_SPECIFIC,
410+
];
411+
412+
for (const synth of allSynthesizers) {
413+
const checkbox = screen.queryByTestId(
414+
`queryDistributionCheckbox-${synth}`,
415+
);
416+
if (checkbox) {
417+
// Check if checkbox is currently selected (either checked attribute or data-selected attribute)
418+
const isChecked =
419+
checkbox.getAttribute("checked") !== null ||
420+
checkbox.getAttribute("data-selected") === "true";
421+
422+
if (isChecked) {
423+
fireEvent.click(checkbox);
424+
// Wait a bit for the state to update
425+
await new Promise((resolve) => setTimeout(resolve, 10));
426+
}
427+
}
428+
}
409429

410-
fireEvent.change(sliderInput, {
411-
target: { value: clear ? 0 : weight },
412-
});
430+
// Then check the ones we want
431+
for (const synth of selectedSynthesizers) {
432+
const checkbox = screen.getByTestId(
433+
`queryDistributionCheckbox-${synth}`,
434+
);
435+
// Check if checkbox is currently unselected
436+
const isChecked =
437+
checkbox.getAttribute("checked") !== null ||
438+
checkbox.getAttribute("data-selected") === "true";
439+
440+
if (!isChecked) {
441+
fireEvent.click(checkbox);
442+
// Wait a bit for the state to update
443+
await new Promise((resolve) => setTimeout(resolve, 10));
444+
}
413445
}
414446
}
415447
};

frontend/app/[locale]/(authenticated)/qa-catalogs/plugins/implementations/ragas.tsx

Lines changed: 44 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ import {
22
Accordion,
33
AccordionItem,
44
Button,
5-
Slider,
5+
Checkbox,
6+
CheckboxGroup,
67
Textarea,
78
} from "@heroui/react";
89
import { cx } from "classix";
@@ -30,16 +31,9 @@ const synthesizerTypes = Object.values(RagasQaCatalogQuerySynthesizer) as [
3031

3132
const synthesizerTypeEnum = z.enum(synthesizerTypes);
3233

33-
const getDefaultSynthesizerValues = (): {
34-
[key: string]: number;
35-
} =>
36-
synthesizerTypes.reduce(
37-
(acc, t) => {
38-
acc[t] = 0;
39-
return acc;
40-
},
41-
{} as { [key: string]: number },
42-
);
34+
const getDefaultSynthesizerValues = (): RagasQaCatalogQuerySynthesizer[] => [
35+
RagasQaCatalogQuerySynthesizer.SINGLE_HOP_SPECIFIC,
36+
];
4337

4438
const ragasGeneratorConfigurationShape = {
4539
config: z.object({
@@ -50,18 +44,9 @@ const ragasGeneratorConfigurationShape = {
5044
})
5145
.int({ message: formErrors.int })
5246
.min(1, { message: formErrors.required }),
53-
queryDistribution: z.record(synthesizerTypeEnum, z.number()).refine(
54-
(distribution) => {
55-
const sum = Object.values(distribution).reduce(
56-
(acc, val) => acc + (val == undefined ? 0 : val),
57-
0,
58-
);
59-
return Math.abs(sum - 1) < Number.EPSILON;
60-
},
61-
{
62-
message: "The sum of all values in queryDistribution must equal 1",
63-
},
64-
),
47+
queryDistribution: z
48+
.array(synthesizerTypeEnum)
49+
.min(1, { message: "At least one query synthesizer must be selected" }),
6550
personas: z
6651
.array(
6752
z.object({
@@ -220,46 +205,38 @@ const QueryDistributionForm = ({
220205
</span>
221206
</div>
222207
<div className="flex flex-col space-y-3">
223-
{synthesizerTypes.map((synth) => (
224-
<Controller
225-
name={`config.queryDistribution.${synth}`}
226-
control={control}
227-
key={synth}
228-
render={({ field, formState: { errors } }) => (
229-
<div className="flex flex-col space-x-2 align-items-center items-end justify-center">
230-
<Slider
231-
defaultValue={field.value}
232-
label={t(
233-
`RagasQACatalogGeneratorConfigurationForm.field.queryDistribution.values.${synth}.title`,
234-
)}
235-
maxValue={1}
236-
minValue={0}
237-
step={0.1}
238-
showSteps={true}
239-
size={"sm"}
240-
onChange={(v) => field.onChange(v)}
241-
color={distributionError ? "danger" : "primary"}
242-
data-testid={`queryDistributionSlider-${synth}`}
243-
/>
244-
<span className="text-default-500 text-sm text-left w-full">
245-
{t(
246-
`RagasQACatalogGeneratorConfigurationForm.field.queryDistribution.values.${synth}.description`,
247-
)}
248-
</span>
249-
{errors.config?.queryDistribution?.[synth] && (
250-
<span className="text-danger text-sm">
251-
{errors.config.queryDistribution[synth].message}
208+
<Controller
209+
name="config.queryDistribution"
210+
control={control}
211+
render={({ field }) => (
212+
<CheckboxGroup
213+
value={field.value}
214+
onValueChange={field.onChange}
215+
orientation="vertical"
216+
color="primary"
217+
isInvalid={queryDistributionInvalid}
218+
errorMessage={distributionError?.message}
219+
>
220+
{synthesizerTypes.map((synth) => (
221+
<div key={synth} className="flex flex-col space-y-1">
222+
<Checkbox
223+
value={synth}
224+
data-testid={`queryDistributionCheckbox-${synth}`}
225+
>
226+
{t(
227+
`RagasQACatalogGeneratorConfigurationForm.field.queryDistribution.values.${synth}.title`,
228+
)}
229+
</Checkbox>
230+
<span className="text-default-500 text-sm ml-6">
231+
{t(
232+
`RagasQACatalogGeneratorConfigurationForm.field.queryDistribution.values.${synth}.description`,
233+
)}
252234
</span>
253-
)}
254-
</div>
255-
)}
256-
/>
257-
))}
258-
{distributionError && (
259-
<span className="text-danger text-sm">
260-
{distributionError.root?.message}
261-
</span>
262-
)}
235+
</div>
236+
))}
237+
</CheckboxGroup>
238+
)}
239+
/>
263240
</div>
264241
</div>
265242
);
@@ -324,11 +301,12 @@ export const ragasGeneratorPlugin = createQACatalogGenerationPlugin({
324301
configurationForm: RagasGeneratorConfigurationForm,
325302
getDefaults: () => ({
326303
config: {
304+
type: "RAGAS" as const,
327305
sampleCount: 5,
328-
queryDistribution: {
329-
...getDefaultSynthesizerValues(),
330-
SINGLE_HOP_SPECIFIC: 1,
331-
},
306+
queryDistribution: getDefaultSynthesizerValues(),
307+
personas: [],
308+
knowledgeGraphLocation: null,
309+
useExistingKnowledgeGraph: true,
332310
},
333311
modelConfig: { llmEndpoint: null },
334312
}),

0 commit comments

Comments
 (0)