@@ -114,78 +114,87 @@ def validate_node_name(value: str) -> str:
114114class Artifacts (BaseModel ):
115115 """Container for storing and managing artifacts generated by pipeline nodes.
116116
117- Modules hyperparams and outputs. The best ones are transmitted between nodes of the pipeline.
117+ Only stores the best artifact for each node type to optimize memory usage.
118+ The best ones are transmitted between nodes of the pipeline.
118119
119120 Attributes:
120- regex: List of artifacts from the regex node.
121- embedding: List of artifacts from the embedding node.
122- scoring: List of artifacts from the scoring node.
123- decision: List of artifacts from the decision node.
121+ regex: Best artifact from the regex node.
122+ embedding: Best artifact from the embedding node.
123+ scoring: Best artifact from the scoring node.
124+ decision: Best artifact from the decision node.
124125 """
125126
126127 model_config = ConfigDict (arbitrary_types_allowed = True )
127128
128- regex : list [ RegexArtifact ] = []
129- embedding : list [ EmbeddingArtifact ] = []
130- scoring : list [ ScorerArtifact ] = []
131- decision : list [ DecisionArtifact ] = []
129+ regex : RegexArtifact | None = None
130+ embedding : EmbeddingArtifact | None = None
131+ scoring : ScorerArtifact | None = None
132+ decision : DecisionArtifact | None = None
132133
133134 def model_dump (self , ** kwargs : Any ) -> dict [str , Any ]: # noqa: ANN401
134135 """Convert the model to a dictionary, ensuring nested artifacts are properly serialized."""
135136 data = super ().model_dump (** kwargs )
136137 for node_type in [NodeType .regex , NodeType .embedding , NodeType .scoring , NodeType .decision ]:
137- artifacts = getattr (self , node_type .value )
138- data [node_type .value ] = [artifact .model_dump (** kwargs ) for artifact in artifacts ]
138+ artifact = getattr (self , node_type .value )
139+ if artifact is not None :
140+ data [node_type .value ] = artifact .model_dump (** kwargs )
141+ else :
142+ data [node_type .value ] = None
139143 return data
140144
141145 @classmethod
142146 def model_validate (cls , obj : dict [str , Any ]) -> "Artifacts" :
143147 """Convert the dictionary back to an Artifacts instance, ensuring nested artifacts are properly deserialized."""
144148 # First convert the lists back to numpy arrays in the scoring artifacts
145- if "scoring" in obj :
146- for artifact in obj ["scoring" ]:
147- if artifact .get ("train_scores" ) is not None :
148- artifact ["train_scores" ] = np .array (artifact ["train_scores" ])
149- if artifact .get ("validation_scores" ) is not None :
150- artifact ["validation_scores" ] = np .array (artifact ["validation_scores" ])
151- if artifact .get ("test_scores" ) is not None :
152- artifact ["test_scores" ] = np .array (artifact ["test_scores" ])
153- if artifact .get ("folded_scores" ) is not None :
154- artifact ["folded_scores" ] = [np .array (arr ) for arr in artifact ["folded_scores" ]]
149+ if "scoring" in obj and obj ["scoring" ] is not None :
150+ if obj ["scoring" ].get ("train_scores" ) is not None :
151+ obj ["scoring" ]["train_scores" ] = np .array (obj ["scoring" ]["train_scores" ])
152+ if obj ["scoring" ].get ("validation_scores" ) is not None :
153+ obj ["scoring" ]["validation_scores" ] = np .array (obj ["scoring" ]["validation_scores" ])
154+ if obj ["scoring" ].get ("test_scores" ) is not None :
155+ obj ["scoring" ]["test_scores" ] = np .array (obj ["scoring" ]["test_scores" ])
156+ if obj ["scoring" ].get ("folded_scores" ) is not None :
157+ obj ["scoring" ]["folded_scores" ] = [np .array (arr ) for arr in obj ["scoring" ]["folded_scores" ]]
155158
156159 return super ().model_validate (obj )
157160
158161 def add_artifact (self , node_type : str , artifact : Artifact ) -> None :
159- """Add an artifact to the specified node type.
162+ """Add an artifact to the specified node type, replacing any existing artifact .
160163
161164 Args:
162165 node_type: Node type as a string.
163166 artifact: The artifact to add.
164167 """
165- self . get_artifacts (node_type ). append ( artifact )
168+ setattr ( self , validate_node_name (node_type ), artifact )
166169
167- def get_artifacts (self , node_type : str ) -> list [ Artifact ] :
168- """Retrieve all artifacts for a specified node type.
170+ def get_artifact (self , node_type : str ) -> Artifact | None :
171+ """Retrieve the artifact for a specified node type.
169172
170173 Args:
171174 node_type: Node type as a string.
172175
173176 Returns:
174- A list of artifacts for the node type.
177+ The artifact for the node type, or None if no artifact exists .
175178 """
176179 return getattr (self , validate_node_name (node_type )) # type: ignore[no-any-return]
177180
178- def get_best_artifact (self , node_type : str , idx : int ) -> Artifact :
179- """Retrieve the best artifact for a specified node type and index .
181+ def get_best_artifact (self , node_type : str ) -> Artifact :
182+ """Retrieve the artifact for a specified node type.
180183
181184 Args:
182185 node_type: Node type as a string.
183- idx: Index of the artifact.
184186
185187 Returns:
186- The best artifact.
188+ The artifact for the node type.
189+
190+ Raises:
191+ ValueError: If no artifact exists for the node type.
187192 """
188- return self .get_artifacts (node_type )[idx ]
193+ artifact = self .get_artifact (node_type )
194+ if artifact is None :
195+ msg = f"No artifact for { node_type } "
196+ raise ValueError (msg )
197+ return artifact
189198
190199 def has_artifacts (self ) -> bool :
191200 """Check if any artifacts have been saved in RAM.
@@ -194,7 +203,7 @@ def has_artifacts(self) -> bool:
194203 True if any artifacts exist, False otherwise.
195204 """
196205 node_types = [NodeType .regex , NodeType .embedding , NodeType .scoring , NodeType .decision ]
197- return any (len ( self .get_artifacts (nt )) > 0 for nt in node_types )
206+ return any (self .get_artifact (nt ) is not None for nt in node_types )
198207
199208
200209class Trial (BaseModel ):
@@ -263,39 +272,3 @@ def add_trial(self, node_type: str, trial: Trial) -> None:
263272 trial: The trial to add.
264273 """
265274 self .get_trials (node_type ).append (trial )
266-
267-
268- class TrialsIds (BaseModel ):
269- """Representation of the best trial IDs for each pipeline node.
270-
271- Attributes:
272- regex: Best trial index for the regex node.
273- embedding: Best trial index for the embedding node.
274- scoring: Best trial index for the scoring node.
275- decision: Best trial index for the decision node.
276- """
277-
278- regex : int | None = None
279- embedding : int | None = None
280- scoring : int | None = None
281- decision : int | None = None
282-
283- def get_best_trial_idx (self , node_type : str ) -> int | None :
284- """Retrieve the best trial index for a specified node type.
285-
286- Args:
287- node_type: Node type as a string.
288-
289- Returns:
290- The index of the best trial, or None if not set.
291- """
292- return getattr (self , validate_node_name (node_type )) # type: ignore[no-any-return]
293-
294- def set_best_trial_idx (self , node_type : str , idx : int ) -> None :
295- """Set the best trial index for a specified node type.
296-
297- Args:
298- node_type: Node type as a string.
299- idx: Index of the best trial.
300- """
301- setattr (self , validate_node_name (node_type ), idx )
0 commit comments