Skip to content

Commit 686ee75

Browse files
committed
catalog generation unit tests fixed
1 parent 154c1c2 commit 686ee75

File tree

3 files changed

+87
-123
lines changed

3 files changed

+87
-123
lines changed

backend/llm_eval/qa_catalog/generator/implementation/ragas/generator.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -303,6 +303,7 @@ def load_exiting_knowledge_graph(
303303
return None
304304

305305
logger.info(f"Knowledge graph found at {self.config.knowledge_graph_location}")
306+
knowledge_graph = None
306307
try:
307308
knowledge_graph = KnowledgeGraph.load(self.config.knowledge_graph_location)
308309
except Exception as e:

backend/tests/unit/backend/qa_catalog/logic/test_generation.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -124,12 +124,8 @@ def catalog_generation_data(
124124
def session_local() -> MockSessionLocal:
125125
@contextmanager
126126
def _func(): # noqa: ANN202
127-
with patch(
128-
"llm_eval.qa_catalog.logic.generation.AsyncSessionLocal",
129-
) as session_local:
130-
mock_session = AsyncMock()
131-
session_local.begin.return_value.__aenter__.return_value = mock_session
132-
yield mock_session
127+
mock_session = AsyncMock()
128+
yield mock_session
133129

134130
return _func
135131

@@ -153,7 +149,7 @@ async def test_generate_catalog_task_happy_path(
153149
mock_find_qa_catalog.return_value = qa_catalog
154150
mock_find_data_source_config.return_value = temp_data_source_config
155151

156-
await generate_catalog(qa_catalog.id, catalog_generation_data)
152+
await generate_catalog(mock_session, qa_catalog.id, catalog_generation_data)
157153

158154
mock_find_qa_catalog.assert_called_once_with(mock_session, qa_catalog.id)
159155
mock_find_data_source_config.assert_called_once_with(
@@ -194,7 +190,7 @@ async def test_generate_catalog_task_temp_data_source_config_not_found(
194190
mock_find_qa_catalog.return_value = qa_catalog
195191
mock_find_data_source_config.return_value = None
196192

197-
await generate_catalog(qa_catalog.id, catalog_generation_data)
193+
await generate_catalog(mock_session, qa_catalog.id, catalog_generation_data)
198194

199195
mock_find_qa_catalog.assert_called_once_with(mock_session, qa_catalog.id)
200196
mock_find_data_source_config.assert_called_once_with(

backend/tests/unit/qa_catalog/generator/implementation/ragas/test_generator.py

Lines changed: 82 additions & 115 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,9 @@
22
from pathlib import Path
33
from random import shuffle
44
from typing import Callable, ContextManager, Generator
5-
from unittest.mock import AsyncMock, MagicMock, patch
5+
from unittest.mock import MagicMock, patch
66

77
import pytest
8-
from anyio import CapacityLimiter
9-
from anyio.streams.memory import MemoryObjectSendStream
108
from langchain_core.documents import Document
119
from ragas.testset.graph import Node, NodeType
1210

@@ -234,31 +232,27 @@ async def test_ragas_generator_has_same_nodes(
234232

235233

236234
@pytest.mark.asyncio
237-
@patch("llm_eval.qa_catalog.generator.implementation.ragas.generator.create_backup")
238235
@patch("llm_eval.qa_catalog.generator.implementation.ragas.generator.KnowledgeGraph")
239236
async def test_ragas_generator_create_knowledge_graph_new_successfully(
240237
mock_knowledge_graph: MagicMock,
241-
mock_create_backup: MagicMock,
242238
ragas_generator_factory: RagasGeneratorFactory,
243239
documents: list[Document],
244240
nodes: list[Node],
245241
) -> None:
246242
with ragas_generator_factory() as generator:
247243
generator._load_and_process_documents = MagicMock(return_value=documents)
248244
generator._create_knowledge_graph_nodes = MagicMock(return_value=nodes)
249-
generator.load_exiting_knowledge_graph = MagicMock(return_value=None)
245+
generator.split_documents = MagicMock(return_value=documents)
250246
generator.apply_knowledge_graph_transformations = MagicMock()
251247

252248
created_kg = mock_knowledge_graph(nodes=nodes)
253249
mock_knowledge_graph.return_value = created_kg
254250

255251
kg = generator.create_knowledge_graph()
256252

253+
generator._load_and_process_documents.assert_called_once()
254+
generator.split_documents.assert_called_once_with(documents)
257255
generator._create_knowledge_graph_nodes.assert_called_once_with(documents)
258-
generator.load_exiting_knowledge_graph.assert_called_once()
259-
mock_create_backup.assert_called_once_with(
260-
generator.config.knowledge_graph_location
261-
)
262256
mock_knowledge_graph.assert_called_with(nodes=nodes)
263257
generator.apply_knowledge_graph_transformations.assert_called_once_with(
264258
created_kg
@@ -277,16 +271,16 @@ async def test_ragas_generator_create_knowledge_graph_use_existing_graph_when_no
277271
with ragas_generator_factory() as generator:
278272
generator._load_and_process_documents = MagicMock(return_value=documents)
279273
generator._create_knowledge_graph_nodes = MagicMock(return_value=nodes)
274+
generator.split_documents = MagicMock(return_value=documents)
280275
generator.config.knowledge_graph_location = None
281276
generator.apply_knowledge_graph_transformations = MagicMock(return_value=None)
282-
existing_graph = MagicMock(name="existing_graph", nodes=nodes)
283-
generator.load_exiting_knowledge_graph = MagicMock(return_value=existing_graph)
284-
mock_knowledge_graph.return_value = MagicMock(name="generated_graph")
277+
generated_graph = MagicMock(name="generated_graph")
278+
mock_knowledge_graph.return_value = generated_graph
285279

286280
kg: MagicMock = generator.create_knowledge_graph() # type: ignore
287281

288282
assert kg is not None
289-
assert kg == existing_graph
283+
assert kg == generated_graph
290284

291285

292286
@pytest.mark.asyncio
@@ -310,14 +304,24 @@ async def test_ragas_generator_load_knowledge_graph(
310304
with ragas_generator_factory() as generator:
311305
generator.config.knowledge_graph_location = MagicMock()
312306
generator.config.knowledge_graph_location.exists = MagicMock(return_value=True)
313-
mock_knowledge_graph.load.return_value = MagicMock(name="existing_graph")
307+
loaded_graph = MagicMock(name="existing_graph")
308+
mock_knowledge_graph.load.return_value = loaded_graph
309+
310+
# Create proper mock documents with required attributes
311+
mock_doc = MagicMock()
312+
mock_doc.page_content = "test content"
313+
mock_doc.metadata = {"source": "test.txt"}
314+
docs = [mock_doc]
314315

315-
kg = generator.load_exiting_knowledge_graph()
316+
# Mock _has_same_nodes to return True so the loaded graph is returned
317+
generator._has_same_nodes = MagicMock(return_value=True)
318+
319+
kg = generator.load_exiting_knowledge_graph(docs)
316320

317321
mock_knowledge_graph.load.assert_called_once_with(
318322
generator.config.knowledge_graph_location
319323
)
320-
assert kg is not None
324+
assert kg is loaded_graph
321325

322326

323327
@pytest.mark.asyncio
@@ -330,7 +334,8 @@ async def test_ragas_generator_load_knowledge_graph_fails_on_non_existent_file(
330334
generator.config.knowledge_graph_location = MagicMock()
331335
generator.config.knowledge_graph_location.exists = MagicMock(return_value=False)
332336

333-
kg = generator.load_exiting_knowledge_graph()
337+
docs = [MagicMock()] # Mock documents
338+
kg = generator.load_exiting_knowledge_graph(docs)
334339

335340
assert kg is None
336341

@@ -342,7 +347,8 @@ async def test_ragas_generator_load_knowledge_graph_fails_on_none_location(
342347
with ragas_generator_factory() as generator:
343348
generator.config.knowledge_graph_location = None
344349

345-
kg = generator.load_exiting_knowledge_graph()
350+
docs = [MagicMock()] # Mock documents
351+
kg = generator.load_exiting_knowledge_graph(docs)
346352

347353
assert kg is None
348354

@@ -362,7 +368,13 @@ async def test_ragas_generator_load_knowledge_graph_fails_on_load(
362368
generator.config.knowledge_graph_location.exists = MagicMock(return_value=True)
363369
mock_knowledge_graph.load.side_effect = RuntimeError("Load failed")
364370

365-
kg = generator.load_exiting_knowledge_graph()
371+
# Create proper mock documents with required attributes
372+
mock_doc = MagicMock()
373+
mock_doc.page_content = "test content"
374+
mock_doc.metadata = {"source": "test.txt"}
375+
docs = [mock_doc]
376+
377+
kg = generator.load_exiting_knowledge_graph(docs)
366378

367379
mock_knowledge_graph.load.assert_called_once_with(
368380
generator.config.knowledge_graph_location
@@ -382,117 +394,72 @@ async def test_ragas_generator_generate_testset(
382394

383395
mock_testset_generator = MagicMock()
384396
mock_testset_generator.generate.return_value = ["generated query"]
385-
testset = generator.generate_testset(mock_testset_generator, 1)
386-
387-
assert testset[0] == "generated query" # type: ignore
388-
389-
390-
@pytest.mark.asyncio
391-
@patch(
392-
"llm_eval.qa_catalog.generator.implementation.ragas.generator.ragas_sample_to_synthetic_qa_pair",
393-
)
394-
async def test_ragas_generator_generate_single_sample(
395-
mock_from_ragas: MagicMock,
396-
ragas_generator_factory: RagasGeneratorFactory,
397-
) -> None:
398-
with ragas_generator_factory() as generator:
399-
generator.generate_testset = MagicMock(return_value=MagicMock(samples=[1]))
400-
mock_from_ragas.return_value = "sample"
401-
402-
send_sample = MagicMock()
403-
send_sample.send = AsyncMock()
404-
limiter = MagicMock()
405-
406-
await generator._generate_samples(MagicMock(), send_sample, limiter)
407397

408-
send_sample.send.assert_called_once_with("sample")
398+
# Provide a non-empty query distribution
399+
mock_synthesizer = MagicMock()
400+
query_distribution = [(mock_synthesizer, 1.0)]
409401

410-
411-
@pytest.mark.asyncio
412-
async def test_ragas_generator_generate_single_sample_wont_send_when_fails(
413-
ragas_generator_factory: RagasGeneratorFactory,
414-
) -> None:
415-
with ragas_generator_factory() as generator:
416-
generator.generate_testset = MagicMock(
417-
side_effect=RuntimeError("sample generation failed")
402+
testset = generator.generate_testset(
403+
mock_testset_generator, 1, query_distribution
418404
)
419405

420-
send_sample = MagicMock()
421-
send_sample.send = AsyncMock()
422-
limiter = MagicMock()
423-
424-
await generator._generate_samples(MagicMock(), send_sample, limiter)
425-
426-
send_sample.send.assert_not_called()
427-
428-
429-
@pytest.mark.asyncio
430-
async def test_ragas_generator_generate_single_sample_does_nothing_on_empty_generated_testset( # noqa: E501
431-
ragas_generator_factory: RagasGeneratorFactory,
432-
) -> None:
433-
with ragas_generator_factory() as generator:
434-
generator.generate_testset = MagicMock(return_value=MagicMock(samples=[]))
435-
436-
send_sample = MagicMock()
437-
send_sample.send = AsyncMock()
438-
limiter = MagicMock()
439-
440-
await generator._generate_samples(MagicMock(), send_sample, limiter)
441-
442-
send_sample.send.assert_not_called()
406+
assert testset[0] == "generated query" # type: ignore
443407

444408

445409
@pytest.mark.asyncio
446410
@patch(
447-
"llm_eval.qa_catalog.generator.implementation.ragas.generator.TestsetGenerator",
448-
new_callable=AsyncMock,
449-
)
450-
@patch(
451-
"llm_eval.qa_catalog.generator.implementation.ragas.generator.RagasQACatalogGenerator.load_chat_model",
452-
new_callable=AsyncMock,
453-
)
454-
@patch(
455-
"llm_eval.qa_catalog.generator.implementation.ragas.generator.copy",
456-
new_callable=AsyncMock,
411+
"llm_eval.qa_catalog.generator.implementation.ragas.generator.ragas_sample_to_synthetic_qa_pair"
457412
)
413+
@patch("ragas.testset.persona.generate_personas_from_kg")
458414
async def test_ragas_generator_a_create_synthetic_qa_successfull(
459-
mock_copy: AsyncMock,
460-
mock_load_chat_model: AsyncMock,
461-
mock_testset_generator: AsyncMock,
462-
config: RagasQACatalogGeneratorConfig,
463-
ragas_model_config: RagasQACatalogGeneratorModelConfig,
464-
data_source_config: QACatalogGeneratorDataSourceConfig,
415+
mock_generate_personas: MagicMock,
416+
mock_ragas_sample_to_qa_pair: MagicMock,
417+
ragas_generator_factory: RagasGeneratorFactory,
465418
) -> None:
466-
generator = RagasQACatalogGenerator(config, data_source_config, ragas_model_config)
467-
generator.create_knowledge_graph = AsyncMock()
468-
469-
sample = SyntheticQAPair(
470-
id="1",
471-
question="question",
472-
expected_output="expected_output",
473-
contexts=[],
474-
meta_data={},
475-
)
476-
477-
async def mock_generate(
478-
_, # noqa: ANN001
479-
send_sample: MemoryObjectSendStream[SyntheticQAPair | None],
480-
limiter: CapacityLimiter,
481-
) -> None:
482-
async with limiter:
483-
async with send_sample:
484-
await send_sample.send(sample)
419+
with ragas_generator_factory() as generator:
420+
# Mock knowledge graph
421+
mock_kg = MagicMock()
422+
mock_kg.nodes = []
423+
generator.create_knowledge_graph = MagicMock(return_value=mock_kg)
424+
425+
# Mock personas
426+
mock_persona = MagicMock()
427+
mock_generate_personas.return_value = [mock_persona]
428+
429+
# Mock query distribution
430+
mock_synthesizer = MagicMock()
431+
generator.create_query_distribution = MagicMock(
432+
return_value=[(mock_synthesizer, 1.0)]
433+
)
485434

486-
generator._generate_samples = AsyncMock(side_effect=mock_generate)
435+
# Mock testset generation
436+
mock_testset = MagicMock()
437+
mock_sample = MagicMock()
438+
mock_testset.samples = [mock_sample]
439+
generator.generate_testset = MagicMock(return_value=mock_testset)
440+
441+
# Mock conversion to SyntheticQAPair
442+
synthetic_qa_pair = SyntheticQAPair(
443+
id="1",
444+
question="test question",
445+
expected_output="test answer",
446+
contexts=[],
447+
meta_data={},
448+
)
449+
mock_ragas_sample_to_qa_pair.return_value = synthetic_qa_pair
487450

488-
samples = []
451+
# Collect samples
452+
collected_samples = []
489453

490-
async def process_samples_fn(samples_batch: list[SyntheticQAPair]) -> None:
491-
samples.extend(samples_batch)
454+
async def collect_samples_fn(samples: list[SyntheticQAPair]) -> None:
455+
collected_samples.extend(samples)
492456

493-
await generator.a_create_synthetic_qa(process_samples_fn)
457+
# Test the method
458+
await generator.a_create_synthetic_qa(collect_samples_fn)
494459

495-
assert len(samples) == generator.config.sample_count
460+
# Verify results
461+
assert len(collected_samples) == 1
462+
assert collected_samples[0] == synthetic_qa_pair
496463

497464

498465
@pytest.mark.asyncio

0 commit comments

Comments
 (0)