Skip to content

Commit 31c5a64

Browse files
feat: add config and operator node types
1 parent 761e64b commit 31c5a64

File tree

4 files changed

+38
-60
lines changed

4 files changed

+38
-60
lines changed

graphgen/bases/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
StorageNameSpace,
1414
)
1515
from .base_tokenizer import BaseTokenizer
16-
from .datatypes import Chunk, QAPair, Token
16+
from .datatypes import Chunk, Config, Node, QAPair, Token

graphgen/bases/datatypes.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from dataclasses import dataclass, field
33
from typing import List, Union
44

5+
from pydantic import BaseModel, Field, field_validator
6+
57

68
@dataclass
79
class Chunk:
@@ -48,3 +50,38 @@ class Community:
4850
nodes: List[str] = field(default_factory=list)
4951
edges: List[tuple] = field(default_factory=list)
5052
metadata: dict = field(default_factory=dict)
53+
54+
55+
class Node(BaseModel):
    """Pydantic model describing a single operator node.

    ``id``, ``op_name`` and ``type`` are required; ``params`` and
    ``dependencies`` default to a fresh empty ``dict`` / ``list``.
    ``type`` is restricted to a fixed set of task types by
    :meth:`validate_type`.
    """

    id: str = Field(..., description="unique node id")
    op_name: str = Field(..., description="operator name")
    type: str = Field(
        ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch"
    )
    params: dict = Field(default_factory=dict, description="operator parameters")
    # NOTE(review): presumably the ids of the nodes this node depends on;
    # nothing in this class checks that they refer to existing nodes.
    dependencies: List[str] = Field(
        default_factory=list, description="list of dependent node ids"
    )

    # ``@field_validator`` must be the OUTERMOST decorator. It returns a
    # pydantic descriptor proxy that the model metaclass looks for in the
    # class namespace; wrapping that proxy in ``@classmethod`` hides it, and
    # the validator would then be silently skipped.
    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        """Ensure ``type`` is one of the supported task types.

        Args:
            v: The candidate value for the ``type`` field.

        Returns:
            ``v`` unchanged when it is a supported task type.

        Raises:
            ValueError: If ``v`` is not one of ``map``, ``filter``,
                ``flatmap``, ``aggregate`` or ``map_batch``.
        """
        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
        if v not in valid_types:
            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
        return v
73+
74+
75+
class Config(BaseModel):
    """Pydantic model holding the list of nodes of a computation graph.

    ``nodes`` is required and must contain at least one entry
    (``min_length=1``). Node ids must be unique; this is enforced by
    :meth:`validate_unique_ids`.
    """

    nodes: List[Node] = Field(
        ..., min_length=1, description="list of nodes in the computation graph"
    )

    # ``@field_validator`` must be the OUTERMOST decorator. It returns a
    # pydantic descriptor proxy that the model metaclass looks for in the
    # class namespace; wrapping that proxy in ``@classmethod`` hides it, and
    # the validator would then be silently skipped.
    @field_validator("nodes")
    @classmethod
    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
        """Ensure no two nodes share the same ``id``.

        Args:
            v: The already-parsed list of nodes.

        Returns:
            ``v`` unchanged when every node id is unique.

        Raises:
            ValueError: If at least one id occurs more than once; the
                message lists the offending ids.
        """
        # Single pass: an id seen for the second time is a duplicate.
        seen = set()
        duplicates = set()
        for node in v:
            if node.id in seen:
                duplicates.add(node.id)
            else:
                seen.add(node.id)
        if duplicates:
            raise ValueError(f"Duplicate node ids found: {duplicates}")
        return v

graphgen/operators/storage.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

0 commit comments

Comments
 (0)