2424#
2525###############################################################################
2626import abc
27- import io
2827import os
2928from typing import Optional
3029
31- from pydantic import BaseModel , Field , field_validator
30+ from pydantic import BaseModel
3231
3332
3433class CommandArtifact (BaseModel ):
@@ -40,45 +39,56 @@ class CommandArtifact(BaseModel):
4039 exit_code : int
4140
4241
43- class FileArtifact (BaseModel ):
44- """Artifact to contains contents of file read into memory"""
45-
class BaseFileArtifact(BaseModel, abc.ABC):
    """Abstract base for a file artifact identified by its filename.

    Subclasses must implement ``log_model`` and ``contents_str``.
    ``from_bytes`` is the factory that picks the concrete subclass for a raw
    byte payload.
    """

    filename: str

    @abc.abstractmethod
    def log_model(self, log_path: str) -> None:
        """Write the artifact to disk.

        Args:
            log_path (str): path the artifact is written under
        """

    @abc.abstractmethod
    def contents_str(self) -> str:
        """Return the artifact contents as a string."""

    @classmethod
    def from_bytes(
        cls,
        filename: str,
        raw_contents: bytes,
        encoding: Optional[str] = "utf-8",
        strip: bool = True,
    ) -> "BaseFileArtifact":
        """Build a text or binary artifact from raw bytes.

        Args:
            filename (str): name stored on the artifact
            raw_contents (bytes): raw file contents
            encoding (str | None): encoding used to decode ``raw_contents``.
                ``None`` skips decoding and yields a binary artifact.
                Defaults to "utf-8".
            strip (bool): strip leading/trailing whitespace from decoded text.
                Not applied to binary artifacts.

        Returns:
            BaseFileArtifact: a ``TextFileArtifact`` when the bytes decode with
            ``encoding``; otherwise a ``BinaryFileArtifact`` holding the raw bytes
        """
        if encoding is None:
            return BinaryFileArtifact(filename=filename, contents=raw_contents)

        # Guard only the decode: a decode failure means "treat as binary", and
        # errors raised while building the text artifact must not be mistaken
        # for one.
        try:
            text = raw_contents.decode(encoding)
        except UnicodeDecodeError:
            return BinaryFileArtifact(filename=filename, contents=raw_contents)

        return TextFileArtifact(filename=filename, contents=text.strip() if strip else text)
6069
61- Args:
62- log_path (str): path to write the file
63- encoding (str | None): if None, auto-detect binary or not
64- """
70+
class TextFileArtifact(BaseFileArtifact):
    """File artifact whose contents are held as text."""

    contents: str

    def log_model(self, log_path: str) -> None:
        """Write the text contents to ``<log_path>/<filename>`` as UTF-8.

        Args:
            log_path (str): directory the file is written into
        """
        path = os.path.join(log_path, self.filename)
        # newline="" turns off newline translation on write, so the file holds
        # exactly the characters in ``contents`` on every platform (otherwise
        # "\n" is rewritten to os.linesep and "\r\n" can become "\r\r\n").
        with open(path, "w", encoding="utf-8", newline="") as f:
            f.write(self.contents)

    def contents_str(self) -> str:
        """Return the text contents unchanged."""
        return self.contents
81+
82+
83+ class BinaryFileArtifact (BaseFileArtifact ):
84+ contents : bytes
85+
86+ def log_model (self , log_path : str ) -> None :
6587 log_name = os .path .join (log_path , self .filename )
66- contents = self .contents
67-
68- if encoding :
69- with open (log_name , "w" , encoding = encoding ) as f :
70- f .write (contents .decode (encoding ))
71- else :
72- try :
73- decoded = contents .decode ("utf-8" )
74- with open (log_name , "w" , encoding = "utf-8" ) as f :
75- f .write (decoded )
76- except UnicodeDecodeError :
77- with open (log_name , "wb" ) as f :
78- f .write (contents )
88+ with open (log_name , "wb" ) as f :
89+ f .write (self .contents )
7990
8091 def contents_str (self ) -> str :
81- """Safe string representation of contents (for logs)."""
8292 try :
8393 return self .contents .decode ("utf-8" )
8494 except UnicodeDecodeError :
@@ -104,14 +114,16 @@ def run_command(
104114 """
105115
@abc.abstractmethod
def read_file(
    self, filename: str, encoding: str = "utf-8", strip: bool = True
) -> BaseFileArtifact:
    """Read the named file and return it as a BaseFileArtifact

    Args:
        filename (str): name of the file to read
        encoding (str, optional): encoding used when opening the file. Defaults to "utf-8".
        strip (bool): whether the file contents are stripped automatically

    Returns:
        BaseFileArtifact: artifact for the file that was read
    """
0 commit comments