Skip to content

Commit fcde659

Browse files
sebampueroSebastiansonegs0010aor
authored
feat: add AI-powered flashcard and collection generation
feat: add AI-powered flashcard and collection generation --------- Co-authored-by: Sebastian <[email protected]> Co-authored-by: Miguel Cobo <[email protected]> Co-authored-by: ZorroGuardaPavos <[email protected]>
1 parent b2d7b53 commit fcde659

File tree

22 files changed

+372
-231
lines changed

22 files changed

+372
-231
lines changed

backend/.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ AI_MODEL=
2323
AI_API_KEY=
2424

2525
COLLECTION_GENERATION_PROMPT="I want to generate flashcards on a specific topic for efficient studying. Please create a set of flashcards covering key concepts, definitions, important details, and examples, with a focus on progressively building understanding of the topic. The flashcards should aim to provide a helpful learning experience by using structured explanations, real-world examples and formatting. Each flashcard should follow this format: Front (Question/Prompt): A clear and concise question or term to test recall, starting with introductory concepts and moving toward more complex details. Back (Answer): If the front is a concept or topic, provide a detailed explanation, broken down into clear paragraphs with easy-to-understand language. If possible, include a real-world example, analogy or illustrative diagrams to make the concept more memorable and relatable. If the front is a vocabulary word (for language learning), provide a direct translation in the target language. Optional Hint: A short clue to aid recall, especially for more complex concepts. Important: Use valid Markdown format for the back of the flashcard."
26+
CARD_GENERATION_PROMPT="I want to generate a flashcard on a specific topic. The contents of the flashcard should provide helpful information that aim to help the learner retain the concepts given. The flashcard must follow this format: Front (Question/Prompt): A clear and concise question or term to test recall. Back (Answer): If the front is a concept or topic, provide a detailed explanation, broken down into clear paragraphs with easy-to-understand language. If possible, include a real-world example, analogy or illustrative diagrams to make the concept more memorable and relatable. If the front is a vocabulary word (for language learning), provide a direct translation in the target language. Important: Use valid Markdown format for the back of the flashcard."

backend/src/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
6666
AI_MODEL: str | None = None
6767

6868
COLLECTION_GENERATION_PROMPT: str | None = None
69+
CARD_GENERATION_PROMPT: str | None = None
6970

7071
@computed_field # type: ignore[prop-decorator]
7172
@property

backend/src/flashcards/ai_config.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@
44
from src.core.config import settings
55

66

7+
def card_response_schema(schema_type):
8+
return schema_type.Schema(
9+
type=schema_type.Type.OBJECT,
10+
required=["front", "back"],
11+
properties={
12+
"front": schema_type.Schema(
13+
type=schema_type.Type.STRING,
14+
),
15+
"back": schema_type.Schema(
16+
type=schema_type.Type.STRING,
17+
),
18+
},
19+
)
20+
21+
722
def collection_response_schema(schema_type):
823
return schema_type.Schema(
924
type=schema_type.Type.OBJECT,
@@ -18,18 +33,7 @@ def collection_response_schema(schema_type):
1833
),
1934
"cards": schema_type.Schema(
2035
type=schema_type.Type.ARRAY,
21-
items=schema_type.Schema(
22-
type=schema_type.Type.OBJECT,
23-
required=["front", "back"],
24-
properties={
25-
"front": schema_type.Schema(
26-
type=schema_type.Type.STRING,
27-
),
28-
"back": schema_type.Schema(
29-
type=schema_type.Type.STRING,
30-
),
31-
},
32-
),
36+
items=card_response_schema(schema_type),
3337
),
3438
},
3539
),
@@ -42,3 +46,10 @@ def get_flashcard_config(schema_type) -> types.GenerateContentConfig:
4246
response_schema=collection_response_schema(schema_type),
4347
system_instruction=settings.COLLECTION_GENERATION_PROMPT,
4448
)
49+
50+
51+
def get_card_config(schema_type) -> types.GenerateContentConfig:
52+
return create_content_config(
53+
response_schema=card_response_schema(schema_type),
54+
system_instruction=settings.CARD_GENERATION_PROMPT,
55+
)

backend/src/flashcards/api.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import uuid
23
from typing import Any, Literal
34

@@ -46,25 +47,28 @@ async def create_collection(
4647
collection_in: CollectionCreate,
4748
provider: GeminiProviderDep,
4849
) -> Any:
50+
name = collection_in.name
51+
cards = None
52+
4953
if collection_in.prompt:
5054
try:
51-
collection = await services.generate_ai_collection(
52-
session=session,
53-
user_id=current_user.id,
54-
prompt=collection_in.prompt,
55-
provider=provider,
55+
flashcard_collection = await services.generate_ai_collection(
56+
provider, collection_in.prompt
5657
)
57-
return collection
58+
name = flashcard_collection.name
59+
cards = flashcard_collection.cards
5860
except EmptyCollectionError:
5961
raise HTTPException(
6062
status_code=400, detail="Failed to generate flashcards from the prompt"
6163
)
6264
except AIGenerationError as e:
6365
raise HTTPException(status_code=500, detail=str(e))
64-
else:
65-
return services.create_collection(
66-
session=session, collection_in=collection_in, user_id=current_user.id
66+
67+
return await asyncio.to_thread(
68+
lambda: services.create_collection(
69+
session=session, user_id=current_user.id, name=name, cards=cards
6770
)
71+
)
6872

6973

7074
@router.get("/collections/{collection_id}", response_model=Collection)
@@ -126,16 +130,28 @@ def read_cards(
126130

127131

128132
@router.post("/collections/{collection_id}/cards/", response_model=Card)
129-
def create_card(
133+
async def create_card(
130134
session: SessionDep,
131135
current_user: CurrentUser,
132136
collection_id: uuid.UUID,
133137
card_in: CardCreate,
138+
provider: GeminiProviderDep,
134139
) -> Any:
135-
if not services.check_collection_access(session, collection_id, current_user.id):
140+
access_checked = await asyncio.to_thread(
141+
lambda: services.check_collection_access(
142+
session, collection_id, current_user.id
143+
)
144+
)
145+
if not access_checked:
136146
raise HTTPException(status_code=404, detail="Collection not found")
137-
return services.create_card(
138-
session=session, collection_id=collection_id, card_in=card_in
147+
if card_in.prompt:
148+
card_base = await services.generate_ai_flashcard(card_in.prompt, provider)
149+
card_in.front = card_base.front
150+
card_in.back = card_base.back
151+
return await asyncio.to_thread(
152+
lambda: services.create_card(
153+
session=session, collection_id=collection_id, card_in=card_in
154+
)
139155
)
140156

141157

backend/src/flashcards/schemas.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
import uuid
22
from datetime import datetime
33

4-
from pydantic import BaseModel
5-
from pydantic import Field as PydanticField
6-
from sqlmodel import Field, SQLModel
4+
from pydantic import BaseModel, Field
5+
from sqlmodel import SQLModel
76

87

98
class CollectionBase(SQLModel):
109
name: str
1110

1211

1312
class CollectionCreate(CollectionBase):
14-
prompt: str | None = None
13+
prompt: str | None = Field(default=None, max_length=100)
1514

1615

1716
class CollectionUpdate(SQLModel):
@@ -24,7 +23,7 @@ class CardBase(SQLModel):
2423

2524

2625
class CardCreate(CardBase):
27-
pass
26+
prompt: str | None = Field(default=None, max_length=100)
2827

2928

3029
class CardUpdate(CardBase):
@@ -110,17 +109,13 @@ class PracticeCardResponse(SQLModel):
110109
is_correct: bool | None
111110

112111

113-
class AIFlashcardsRequest(SQLModel):
114-
prompt: str = Field(max_length=100)
115-
116-
117112
class AIFlashcard(BaseModel):
118113
front: str
119114
back: str
120115

121116

122117
class AIFlashcardCollection(BaseModel):
123-
name: str = PydanticField(description="the simple name of the topic")
118+
name: str
124119
cards: list[AIFlashcard]
125120

126121

backend/src/flashcards/services.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
from src.ai_models.gemini.exceptions import AIGenerationError
1212

13-
from .ai_config import get_flashcard_config
13+
from .ai_config import get_card_config, get_flashcard_config
1414
from .exceptions import EmptyCollectionError
1515
from .models import Card, Collection, PracticeCard, PracticeSession
1616
from .schemas import (
1717
AIFlashcardCollection,
18+
CardBase,
1819
CardCreate,
1920
CardUpdate,
20-
CollectionCreate,
2121
CollectionUpdate,
2222
)
2323

@@ -48,11 +48,22 @@ def get_collection(
4848

4949

5050
def create_collection(
51-
session: Session, collection_in: CollectionCreate, user_id: uuid.UUID
51+
session: Session, user_id: uuid.UUID, name: str, cards: list[CardBase] | None = None
5252
) -> Collection:
53-
collection = Collection.model_validate(collection_in, update={"user_id": user_id})
54-
collection.user_id = user_id
53+
collection = Collection(name=name, user_id=user_id)
5554
session.add(collection)
55+
session.flush()
56+
57+
if cards:
58+
card_objs = [
59+
Card(
60+
front=card.front,
61+
back=card.back,
62+
collection_id=collection.id,
63+
)
64+
for card in cards
65+
]
66+
session.add_all(card_objs)
5667
session.commit()
5768
session.refresh(collection)
5869
return collection
@@ -361,7 +372,7 @@ def get_card_by_id(session: Session, card_id: uuid.UUID) -> Card | None:
361372
return session.exec(statement).first()
362373

363374

364-
async def _generate_ai_flashcards(provider, prompt: str) -> AIFlashcardCollection:
375+
async def generate_ai_collection(provider, prompt: str) -> AIFlashcardCollection:
365376
content_config = get_flashcard_config(genai.types)
366377
raw_response = await provider.run_model(content_config, prompt)
367378

@@ -382,33 +393,18 @@ async def _generate_ai_flashcards(provider, prompt: str) -> AIFlashcardCollectio
382393
raise AIGenerationError(f"Error processing AI response: {str(e)}")
383394

384395

385-
def _save_ai_collection(
386-
session: Session, user_id: uuid.UUID, flashcard_collection: AIFlashcardCollection
387-
) -> Collection:
388-
collection = Collection(
389-
name=flashcard_collection.name,
390-
user_id=user_id,
391-
)
392-
session.add(collection)
393-
session.commit()
394-
session.refresh(collection)
395-
396-
for card_data in flashcard_collection.cards:
397-
card = Card(
398-
front=card_data.front,
399-
back=card_data.back,
400-
collection_id=collection.id,
401-
)
402-
session.add(card)
403-
404-
session.commit()
405-
session.refresh(collection)
406-
return collection
407-
408-
409-
async def generate_ai_collection(
410-
session: Session, user_id: uuid.UUID, prompt: str, provider
411-
) -> Collection:
412-
"""Generate a collection of flashcards using AI and save it to the database."""
413-
flashcard_collection = await _generate_ai_flashcards(provider, prompt)
414-
return _save_ai_collection(session, user_id, flashcard_collection)
396+
async def generate_ai_flashcard(prompt: str, provider) -> CardBase:
397+
content_config = get_card_config(genai.types)
398+
raw_response = await provider.run_model(content_config, prompt)
399+
try:
400+
json_data = json.loads(raw_response)
401+
if "front" not in json_data or "back" not in json_data:
402+
raise AIGenerationError("AI response missing 'front' or 'back' field")
403+
card = CardBase(front=json_data["front"], back=json_data["back"])
404+
return card
405+
except json.JSONDecodeError:
406+
raise AIGenerationError("Failed to parse AI response as JSON")
407+
except ValidationError as e:
408+
raise AIGenerationError(f"Invalid AI response format: {str(e)}")
409+
except Exception as e:
410+
raise AIGenerationError(f"Error processing AI response: {str(e)}")

backend/tests/flashcards/card/test_api.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22
from typing import Any
3+
from unittest.mock import ANY, AsyncMock, patch
34

45
import pytest
56
from fastapi.testclient import TestClient
@@ -380,3 +381,46 @@ def test_deleted_card_not_in_list(
380381
list_after = rsp.json()
381382
assert list_after["count"] == len(test_multiple_cards) - 1
382383
assert delete_card["id"] not in [card["id"] for card in list_after["data"]]
384+
385+
386+
def test_create_card_with_prompt_ai(
387+
client: TestClient,
388+
normal_user_token_headers: dict[str, str],
389+
test_collection: dict[str, Any],
390+
):
391+
collection_id = test_collection["id"]
392+
prompt = "What is a closure in Python?"
393+
ai_card = {"front": "What is a closure?", "back": "A closure is..."}
394+
with patch(
395+
"src.flashcards.services.generate_ai_flashcard", new_callable=AsyncMock
396+
) as mock_ai:
397+
mock_ai.return_value = type("Card", (), ai_card)()
398+
card_data = {"prompt": prompt, "front": "", "back": ""}
399+
rsp = client.post(
400+
f"{settings.API_V1_STR}/collections/{collection_id}/cards/",
401+
json=card_data,
402+
headers=normal_user_token_headers,
403+
)
404+
assert rsp.status_code == 200
405+
content = rsp.json()
406+
assert content["front"] == ai_card["front"]
407+
assert content["back"] == ai_card["back"]
408+
mock_ai.assert_called_once_with(prompt, ANY)
409+
410+
411+
def test_create_card_with_prompt_too_long(
412+
client: TestClient,
413+
normal_user_token_headers: dict[str, str],
414+
test_collection: dict[str, Any],
415+
):
416+
collection_id = test_collection["id"]
417+
prompt = "x" * 101
418+
card_data = {"prompt": prompt, "front": "", "back": ""}
419+
rsp = client.post(
420+
f"{settings.API_V1_STR}/collections/{collection_id}/cards/",
421+
json=card_data,
422+
headers=normal_user_token_headers,
423+
)
424+
assert rsp.status_code == 422
425+
content = rsp.json()
426+
assert "prompt" in str(content)

backend/tests/flashcards/collection/test_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,12 @@ def test_create_collection_with_prompt(
9494
assert rsp.status_code == 200
9595
content = rsp.json()
9696
assert content["name"] == collection_data.name
97-
assert content["id"] == str(mock_collection.id)
97+
assert "id" in content
98+
assert isinstance(content["id"], str)
99+
assert len(content["cards"]) == len(mock_collection.cards)
100+
for i, card in enumerate(mock_collection.cards):
101+
assert content["cards"][i]["front"] == card.front
102+
assert content["cards"][i]["back"] == card.back
98103

99104
mock_ai_generate.assert_called_once()
100105

backend/tests/flashcards/collection/test_services.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sqlmodel import Session
55

66
from src.flashcards.models import Collection
7-
from src.flashcards.schemas import CollectionCreate, CollectionUpdate
7+
from src.flashcards.schemas import CollectionUpdate
88
from src.flashcards.services import (
99
check_collection_access,
1010
create_collection,
@@ -16,15 +16,11 @@
1616

1717

1818
def test_create_collection(db: Session, test_user: dict[str, Any]):
19-
collection_in = CollectionCreate(name="Test Collection")
2019
collection = create_collection(
21-
session=db,
22-
collection_in=collection_in,
23-
user_id=test_user["id"],
20+
session=db, user_id=test_user["id"], name="Test Collection"
2421
)
25-
2622
assert collection.id is not None
27-
assert collection.name == collection_in.name
23+
assert collection.name == "Test Collection"
2824
assert collection.user_id == test_user["id"]
2925
assert collection.created_at is not None
3026
assert collection.updated_at is not None

0 commit comments

Comments
 (0)