11import abc
22from itertools import islice
33from pathlib import Path
4- from typing import Any , Generator , Iterable , List , Tuple
4+ from typing import Any , Dict , Generator , Iterable , List , Tuple
55
66from aws_doc_sdk_examples_tools .doc_gen import DocGen , Example
7- from aws_doc_sdk_examples_tools .fs import Fs , PathFs
7+ from aws_doc_sdk_examples_tools .fs import Fs , PathFs
88from aws_doc_sdk_examples_tools .lliam .domain .model import Prompt
99from aws_doc_sdk_examples_tools .lliam .shared_constants import BATCH_PREFIX
1010
@@ -23,113 +23,68 @@ def batched(iterable: Iterable, n: int) -> Generator[Tuple, Any, None]:
2323 yield batch
2424
2525
26- class AbstractPromptRepository (abc .ABC ):
27- def __init__ (self ):
28- self .to_write = {}
29-
30- def add (self , prompt : Prompt ):
31- self ._add (prompt )
32-
33- def all_all (self , prompts : List [Prompt ]):
34- for prompt in prompts :
35- self ._add (prompt )
36-
37- def batch (self , prompts : List [Prompt ]):
38- self ._batch (prompts )
39-
40- def commit (self ):
41- self ._commit ()
42-
43- def get (self , id : str ) -> Prompt :
44- prompt = self ._get (id )
45- return prompt
46-
47- def get_all (self , ids : List [str ]) -> List [Prompt ]:
48- prompts = []
49- for id in ids :
50- prompt = self ._get (id )
51- prompts .append (prompt )
52- return prompts
53-
54- def set_partition (self , name : str ):
55- self .partition_name = name
56-
57- @property
58- def partition (self ):
59- return self .partition_name or ""
26+ class FsPromptRepository :
27+ to_write : Dict [str , str ] = {}
6028
61- @abc .abstractmethod
62- def _add (self , product : Prompt ):
63- raise NotImplementedError
64-
65- @abc .abstractmethod
66- def _batch (self , prompts : List [Prompt ]):
67- raise NotImplementedError
68-
69- @abc .abstractmethod
70- def _commit (self ):
71- raise NotImplementedError
72-
73- @abc .abstractmethod
74- def _get (self , id : str ) -> Prompt :
75- raise NotImplementedError
76-
77-
78- class FsPromptRepository (AbstractPromptRepository ):
7929 def __init__ (self , fs : Fs = PathFs ()):
80- super ().__init__ ()
8130 self .fs = fs
8231
8332 def rollback (self ):
8433 # TODO: This is not what rollback is for. We should be rolling back any
8534 # file changes
8635 self .to_write = {}
8736
88- def _add (self , prompt : Prompt ):
37+ def add (self , prompt : Prompt ):
8938 self .to_write [prompt .id ] = prompt .content
9039
91- def _batch (self , prompts : List [Prompt ]):
40+ def all_all (self , prompts : List [Prompt ]):
41+ for prompt in prompts :
42+ self .add (prompt )
43+
44+ def batch (self , prompts : List [Prompt ]):
9245 for batch_num , batch in enumerate (batched (prompts , DEFAULT_BATCH_SIZE )):
9346 batch_name = f"{ BATCH_PREFIX } { (batch_num + 1 ):03} "
9447 for prompt in batch :
9548 prompt .id = f"{ batch_name } /{ prompt .id } "
96- self ._add (prompt )
49+ self .add (prompt )
9750
98- def _commit (self ):
99- base_path = (
100- Path (self .partition ) if self .partition else Path ("." )
101- )
51+ def commit (self ):
52+ base_path = Path (self .partition ) if self .partition else Path ("." )
10253
10354 for file_path , content in self .to_write .items ():
10455 if content :
10556 full_path = base_path / file_path
10657 self .fs .mkdir (full_path .parent )
10758 self .fs .write (full_path , content )
10859
109- def _get (self , id : str ):
60+ def get (self , id : str ):
11061 return Prompt (id , self .fs .read (Path (id )))
11162
63+ def get_all (self , ids : List [str ]) -> List [Prompt ]:
64+ prompts = []
65+ for id in ids :
66+ prompt = self .get (id )
67+ prompts .append (prompt )
68+ return prompts
11269
113- class AbstractDocGenRepository (abc .ABC ):
114- def get_new_prompts (self , doc_gen_root : str ) -> List [Prompt ]:
115- return self ._get_new_prompts (doc_gen_root )
70+ def set_partition (self , name : str ):
71+ self .partition_name = name
11672
117- @abc . abstractmethod
118- def _get_new_prompts (self , doc_gen_root : str ) -> List [ Prompt ] :
119- raise NotImplementedError
73+ @property
74+ def partition (self ) :
75+ return self . partition_name or ""
12076
12177
122- class FsDocGenRepository ( AbstractDocGenRepository ) :
78+ class FsDocGenRepository :
12379 def __init__ (self , fs : Fs = PathFs ()):
124- super ().__init__ ()
12580 self .fs = fs
126-
81+
12782 def rollback (self ):
12883 # TODO: This is not what rollback is for. We should be rolling back any
12984 # file changes
13085 self ._doc_gen = None
13186
132- def _get_new_prompts (self , doc_gen_root : str ) -> List [Prompt ]:
87+ def get_new_prompts (self , doc_gen_root : str ) -> List [Prompt ]:
13388 # Right now this is the only instance of DocGen used in this Repository,
13489 # but if that changes we need to move it up.
13590 self ._doc_gen = DocGen .from_root (Path (doc_gen_root ), fs = self .fs )
0 commit comments