`.cursor/rules/api-models.mdc` (new file, 104 additions)
---
description: Rules for Pydantic models and request/response validation
globs: ["app.py"]
alwaysApply: true
---

# API Models Guidelines

Pydantic models validate request bodies and ensure type safety. Models are defined directly in `app.py`.

## Model Definition

**Basic structure:**
```python
from pydantic import BaseModel
from typing import Union, Dict, Optional

class WeakSupervisionRequest(BaseModel):
project_id: str
labeling_task_id: str
user_id: str
weak_supervision_task_id: str
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None

class TaskStatsRequest(BaseModel):
project_id: str
labeling_task_id: str
user_id: str

class SourceStatsRequest(BaseModel):
project_id: str
source_id: str
user_id: str

class ExportWsStatsRequest(BaseModel):
project_id: str
labeling_task_id: str
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None
```

## Naming Conventions

- Request bodies: `WeakSupervisionRequest`, `TaskStatsRequest`, `SourceStatsRequest`
- Use descriptive names ending in `Request`
- Match the endpoint purpose (e.g., `WeakSupervisionRequest` for `/fit_predict`)

## Usage in Routes

```python
from fastapi import responses, status

# WeakSupervisionRequest is defined earlier in app.py, so no import is needed

@app.post("/fit_predict")
def weakly_supervise(request: WeakSupervisionRequest) -> responses.PlainTextResponse:
integration.fit_predict(
request.project_id,
request.labeling_task_id,
request.user_id,
request.weak_supervision_task_id,
request.overwrite_weak_supervision,
)
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
```

## Field Validation

```python
from typing import Dict, Optional, Union

from pydantic import BaseModel, Field, field_validator

class WeakSupervisionRequest(BaseModel):
project_id: str = Field(..., min_length=1)
labeling_task_id: str = Field(..., min_length=1)
user_id: str = Field(..., min_length=1)
weak_supervision_task_id: str = Field(..., min_length=1)
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None

@field_validator('project_id', 'labeling_task_id', 'user_id', 'weak_supervision_task_id')
@classmethod
def validate_ids(cls, v):
if not v or not v.strip():
raise ValueError('ID cannot be empty')
return v.strip()
```

## Optional Fields with Complex Types

```python
from typing import Union, Dict, Optional

class WeakSupervisionRequest(BaseModel):
# Optional field that can be a float or a dict
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None
```
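Pydantic accepts either shape at runtime. A minimal sketch of the behavior (assuming Pydantic v2 with its default smart-union mode; the example values are illustrative):

```python
from typing import Dict, Optional, Union

from pydantic import BaseModel


class WeakSupervisionRequest(BaseModel):
    overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None


# Omitted entirely -> falls back to the None default
assert WeakSupervisionRequest().overwrite_weak_supervision is None

# A single float is kept as a float
req = WeakSupervisionRequest(overwrite_weak_supervision=0.8)
assert req.overwrite_weak_supervision == 0.8

# A dict of floats is kept as a dict
req = WeakSupervisionRequest(overwrite_weak_supervision={"label-a": 0.9})
assert req.overwrite_weak_supervision == {"label-a": 0.9}
```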

## Best Practices

1. Use standard Python types (`str`, `int`, `float`); Pydantic performs the validation that FastAPI exposes
2. Use `Optional` for fields that may not be provided
3. Use `Union` for fields that can accept multiple types
4. Provide defaults for optional fields (use `None` for optional values)
5. Use descriptive model names ending in `Request`
6. Use `Field(...)` for required fields with constraints
7. Use field validators for custom validation logic
8. Keep models simple and focused on request structure
9. Define models in `app.py` near the routes that use them
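Practices 6 and 7 can be exercised end to end: constraint violations and validator failures both surface as a `ValidationError` before the route body runs. A minimal sketch (Pydantic v2 assumed):

```python
from pydantic import BaseModel, Field, ValidationError, field_validator


class TaskStatsRequest(BaseModel):
    project_id: str = Field(..., min_length=1)
    labeling_task_id: str = Field(..., min_length=1)
    user_id: str = Field(..., min_length=1)

    @field_validator("project_id", "labeling_task_id", "user_id")
    @classmethod
    def validate_ids(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("ID cannot be empty")
        return v.strip()


# Surrounding whitespace is stripped by the validator
req = TaskStatsRequest(project_id=" p1 ", labeling_task_id="t1", user_id="u1")
assert req.project_id == "p1"

# A whitespace-only ID passes min_length but fails the validator
try:
    TaskStatsRequest(project_id="   ", labeling_task_id="t1", user_id="u1")
except ValidationError:
    pass
else:
    raise AssertionError("expected ValidationError")
```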
`.cursor/rules/controllers.mdc` (new file, 238 additions)
---
description: Rules for controller modules and business logic
globs: ["controller/**/*.py"]
alwaysApply: true
---

# Controllers Guidelines

Controllers contain the business logic for weak supervision: they orchestrate work between routes, shared submodules, and external libraries.

## Import Patterns

```python
# Submodules
from submodules.model.business_objects import (
general,
labeling_task,
record_label_association,
weak_supervision,
labeling_task_label,
information_source,
notification,
project,
user,
)
from submodules.model import enums
from submodules.model.models import (
LabelingTask,
LabelingTaskLabel,
RecordLabelAssociation,
RecordLabelAssociationToken,
)

# Standard library and typing
import traceback
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

# External libraries
import weak_nlp
import pandas as pd

# Controller utilities
from . import util
```

## Function Patterns

**Weak Supervision Integration:**
```python
def fit_predict(
project_id: str,
labeling_task_id: str,
user_id: str,
weak_supervision_task_id: str,
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
):
# Validate and prepare quality metrics
quality_metrics_overwrite = None
if overwrite_weak_supervision is not None:
quality_metrics_overwrite = __create_quality_metrics(
project_id, labeling_task_id, overwrite_weak_supervision
)

# Collect data
task_type, df = collect_data(project_id, labeling_task_id, True)
if len(df.index) == 0:
return

# Process based on task type
try:
if task_type == enums.LabelingTaskType.CLASSIFICATION.value:
results = integrate_classification(df, quality_metrics_overwrite)
else:
results = integrate_extraction(df, quality_metrics_overwrite)

# Store results
weak_supervision.store_data(
project_id,
labeling_task_id,
user_id,
results,
task_type,
weak_supervision_task_id,
with_commit=True,
)
except Exception:
print(traceback.format_exc(), flush=True)
general.rollback()
weak_supervision.update_state(
project_id,
weak_supervision_task_id,
enums.PayloadState.FAILED.value,
with_commit=True,
)
```

**Statistics Calculation:**
```python
def calculate_quality_statistics_for_labeling_task(
project_id: str, task_id: str, user_id: str
):
labeling_task_item = labeling_task.get_labeling_task_by_id_only(task_id)
_, df = integration.collect_data(
labeling_task_item.project_id, labeling_task_item.id, False
)
exclusion_ids = information_source.get_exclusion_record_ids_for_task(task_id)
df = df.loc[~df["record_id"].isin(exclusion_ids)]

try:
if labeling_task_item.task_type == enums.LabelingTaskType.CLASSIFICATION.value:
statistics = classification_quality(df)
else:
statistics = extraction_quality(df)

for source_id, statistics_item in statistics.items():
information_source.update_quality_stats(
labeling_task_item.project_id,
source_id,
statistics_item,
with_commit=True,
)
except weak_nlp.shared.exceptions.MissingReferenceException:
send_warning_no_reference_data(project_id, user_id)
```
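The exclusion step, filtering out records flagged by an information source, can be shown in isolation. A toy sketch (the `record_id` column name matches the code above; the data is made up):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "record_id": ["r1", "r2", "r3"],
        "label_id": ["a", "b", "a"],
    }
)
exclusion_ids = ["r2"]

# Keep only records that are not in the exclusion list
df = df.loc[~df["record_id"].isin(exclusion_ids)]

assert df["record_id"].tolist() == ["r1", "r3"]
```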

## Error Handling Patterns

**Transaction Rollback:**
```python
try:
# Business logic
result = perform_operation()
weak_supervision.store_data(...)
except Exception:
print(traceback.format_exc(), flush=True)
general.rollback()
weak_supervision.update_state(
project_id,
weak_supervision_task_id,
enums.PayloadState.FAILED.value,
with_commit=True,
)
```

**Return Status Codes:**
```python
def export_weak_supervision_stats(
project_id: str,
labeling_task_id: str,
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]] = None,
) -> Tuple[int, str]:
try:
# Process statistics
# ...
return 200, "OK"
except Exception:
print(traceback.format_exc(), flush=True)
general.rollback()
return 500, "Internal server error"
```
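The tuple contract keeps transport concerns out of the controller: the route layer decides how to serialize `(status_code, message)`. A stdlib-only sketch of the pattern (the controller body here is a stand-in, not the real implementation):

```python
import traceback
from typing import Tuple


def export_stats_stub(fail: bool) -> Tuple[int, str]:
    # Stand-in for export_weak_supervision_stats: same (status_code, message) contract
    try:
        if fail:
            raise RuntimeError("simulated processing error")
        return 200, "OK"
    except Exception:
        print(traceback.format_exc(), flush=True)
        return 500, "Internal server error"


assert export_stats_stub(False) == (200, "OK")
assert export_stats_stub(True) == (500, "Internal server error")
```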

## Data Collection Patterns

**Collecting Labeling Data:**
```python
def collect_data(
project_id: str, labeling_task_id: str, only_selected: bool
) -> Tuple[str, pd.DataFrame]:
labeling_task_item = labeling_task.get(project_id, labeling_task_id)

query_results = []
if labeling_task_item.task_type == enums.LabelingTaskType.CLASSIFICATION.value:
# Collect classification associations
for information_source_item in labeling_task_item.information_sources:
if only_selected and not information_source_item.is_selected:
continue
results = record_label_association.get_all_classifications_for_information_source(
project_id, information_source_item.id
)
query_results.extend(results)

request_body = __jsonize_classification_associations(query_results)
records_manual = record_label_association.get_manual_classifications_for_labeling_task_as_json(
project_id, labeling_task_id
)
request_body.extend(records_manual)

return labeling_task_item.task_type, pd.DataFrame(request_body).drop_duplicates()
```

## Integration with weak-nlp

```python
def integrate_classification(
df: pd.DataFrame,
quality_metrics_overwrite: Optional[Dict[Tuple[str, str], Dict[str, float]]] = None,
):
cnlm = util.get_cnlm_from_df(df)
weak_supervision_results = cnlm.weakly_supervise(quality_metrics_overwrite)
return_values = defaultdict(list)
for record_id, (label_id, confidence) in weak_supervision_results.dropna().items():
return_values[record_id].append(
{"label_id": label_id, "confidence": confidence}
)
return return_values
```
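The grouping step above can be illustrated on its own. A self-contained sketch where a plain `pandas` Series stands in for the real `cnlm.weakly_supervise` output (record and label IDs are made up):

```python
from collections import defaultdict

import pandas as pd

# Stand-in for cnlm.weakly_supervise(...): record_id -> (label_id, confidence),
# with None for records that weak supervision could not label
weak_supervision_results = pd.Series(
    {
        "rec-1": ("label-a", 0.90),
        "rec-2": ("label-b", 0.75),
        "rec-3": None,
    }
)

return_values = defaultdict(list)
for record_id, (label_id, confidence) in weak_supervision_results.dropna().items():
    return_values[record_id].append({"label_id": label_id, "confidence": confidence})

# rec-3 is dropped by dropna(); each labeled record keeps its label and confidence
assert "rec-3" not in return_values
assert return_values["rec-1"] == [{"label_id": "label-a", "confidence": 0.90}]
```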

## Private Functions

Use double underscore prefix for internal helper functions:

```python
def __create_quality_metrics(
project_id: str,
labeling_task_id: str,
overwrite_weak_supervision: Union[float, Dict[str, float]],
) -> Dict[Tuple[str, str], Dict[str, float]]:
# Internal helper function
pass

def __jsonize_classification_associations(
    association_tuples: List[Tuple[RecordLabelAssociation, LabelingTask, LabelingTaskLabel]]
) -> List[Dict[str, Any]]:
# Internal helper function
pass
```

## Best Practices

1. Single responsibility per function
2. Always validate inputs and handle edge cases (empty DataFrames, missing data)
3. Use type hints for all parameters and return types
4. Use `with_commit` parameter appropriately for database operations
5. Use submodule business objects, never SQLAlchemy directly
6. Handle exceptions and rollback transactions on errors
7. Use `weak_nlp` library for weak supervision algorithms
8. Return tuples `(status_code, message)` for operations that need status reporting
9. Print tracebacks with `flush=True` for proper logging
10. Check for empty DataFrames before processing
11. Use double underscore prefix (`__`) for private/internal helper functions
12. Add comments explaining complex logic or edge cases