|
2 | 2 | from dataclasses import dataclass, field |
3 | 3 | from typing import List, Union |
4 | 4 |
|
| 5 | +from pydantic import BaseModel, Field, field_validator |
| 6 | + |
5 | 7 |
|
6 | 8 | @dataclass |
7 | 9 | class Chunk: |
@@ -48,3 +50,38 @@ class Community: |
48 | 50 | nodes: List[str] = field(default_factory=list) |
49 | 51 | edges: List[tuple] = field(default_factory=list) |
50 | 52 | metadata: dict = field(default_factory=dict) |
| 53 | + |
| 54 | + |
| 55 | +class Node(BaseModel): |
| 56 | + id: str = Field(..., description="unique node id") |
| 57 | + op_name: str = Field(..., description="operator name") |
| 58 | + type: str = Field( |
| 59 | + ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch" |
| 60 | + ) |
| 61 | + params: dict = Field(default_factory=dict, description="operator parameters") |
| 62 | + dependencies: List[str] = Field( |
| 63 | + default_factory=list, description="list of dependent node ids" |
| 64 | + ) |
| 65 | + |
| 66 | + @classmethod |
| 67 | + @field_validator("type") |
| 68 | + def validate_type(cls, v: str) -> str: |
| 69 | + valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"} |
| 70 | + if v not in valid_types: |
| 71 | + raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.") |
| 72 | + return v |
| 73 | + |
| 74 | + |
| 75 | +class Config(BaseModel): |
| 76 | + nodes: List[Node] = Field( |
| 77 | + ..., min_length=1, description="list of nodes in the computation graph" |
| 78 | + ) |
| 79 | + |
| 80 | + @classmethod |
| 81 | + @field_validator("nodes") |
| 82 | + def validate_unique_ids(cls, v: List[Node]) -> List[Node]: |
| 83 | + ids = [node.id for node in v] |
| 84 | + if len(ids) != len(set(ids)): |
| 85 | + duplicates = {id_ for id_ in ids if ids.count(id_) > 1} |
| 86 | + raise ValueError(f"Duplicate node ids found: {duplicates}") |
| 87 | + return v |
0 commit comments