Skip to content

Commit 1c7339c

Browse files
committed
Add UoW and adapters to Lliam
1 parent 8aaa767 commit 1c7339c

File tree

2 files changed

+216
-0
lines changed

2 files changed

+216
-0
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import abc
2+
from itertools import islice
3+
from pathlib import Path
4+
from typing import Any, Generator, Iterable, List, Tuple
5+
6+
from aws_doc_sdk_examples_tools.doc_gen import DocGen, Example
7+
from aws_doc_sdk_examples_tools.fs import PathFs
8+
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt
9+
from aws_doc_sdk_examples_tools.lliam.shared_constants import BATCH_PREFIX
10+
11+
DEFAULT_METADATA_PREFIX = "DEFAULT"
12+
DEFAULT_BATCH_SIZE = 150
13+
IAM_POLICY_LANGUAGE = "IAMPolicyGrammar"
14+
15+
16+
def batched(iterable: Iterable, n: int) -> Generator[Tuple, Any, None]:
17+
"Batch data into tuples of length n. The last batch may be shorter."
18+
# batched('ABCDEFG', 3) --> ABC DEF G
19+
if n < 1:
20+
raise ValueError("n must be at least one")
21+
it = iter(iterable)
22+
while batch := tuple(islice(it, n)):
23+
yield batch
24+
25+
26+
class AbstractPromptRepository(abc.ABC):
27+
def __init__(self):
28+
self.to_write = {}
29+
30+
def add(self, prompt: Prompt):
31+
self._add(prompt)
32+
33+
def all_all(self, prompts: List[Prompt]):
34+
for prompt in prompts:
35+
self._add(prompt)
36+
37+
def batch(self, prompts: List[Prompt]):
38+
self._batch(prompts)
39+
40+
def commit(self):
41+
self._commit()
42+
43+
def get(self, id: str) -> Prompt:
44+
prompt = self._get(id)
45+
return prompt
46+
47+
def get_all(self, ids: List[str]) -> List[Prompt]:
48+
prompts = []
49+
for id in ids:
50+
prompt = self._get(id)
51+
prompts.append(prompt)
52+
return prompts
53+
54+
def set_partition(self, name: str):
55+
self.partition_name = name
56+
57+
@property
58+
def partition(self):
59+
return self.partition_name or ""
60+
61+
@abc.abstractmethod
62+
def _add(self, product: Prompt):
63+
raise NotImplementedError
64+
65+
@abc.abstractmethod
66+
def _batch(self, prompts: List[Prompt]):
67+
raise NotImplementedError
68+
69+
@abc.abstractmethod
70+
def _commit(self):
71+
raise NotImplementedError
72+
73+
@abc.abstractmethod
74+
def _get(self, id: str) -> Prompt:
75+
raise NotImplementedError
76+
77+
78+
class FsPromptRepository(AbstractPromptRepository):
79+
def __init__(self):
80+
super().__init__()
81+
self.fs = PathFs()
82+
83+
def rollback(self):
84+
# TODO: This is not what rollback is for. We should be rolling back any
85+
# file changes
86+
self.to_write = {}
87+
88+
def _add(self, prompt: Prompt):
89+
self.to_write[prompt.id] = prompt.content
90+
91+
def _batch(self, prompts: List[Prompt]):
92+
for batch_num, batch in enumerate(batched(prompts, DEFAULT_BATCH_SIZE)):
93+
batch_name = f"{BATCH_PREFIX}{(batch_num + 1):03}"
94+
for prompt in batch:
95+
prompt.id = f"{batch_name}/{prompt.id}"
96+
self._add(prompt)
97+
98+
def _commit(self):
99+
base_path = (
100+
Path(self.partition) if self.partition else Path(".")
101+
)
102+
103+
for file_path, content in self.to_write.items():
104+
if content:
105+
full_path = base_path / file_path
106+
self.fs.mkdir(full_path.parent)
107+
self.fs.write(full_path, content)
108+
109+
def _get(self, id: str):
110+
return Prompt(id, self.fs.read(Path(id)))
111+
112+
113+
class AbstractDocGenRepository(abc.ABC):
114+
def get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
115+
return self._get_new_prompts(doc_gen_root)
116+
117+
@abc.abstractmethod
118+
def _get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
119+
raise NotImplementedError
120+
121+
122+
class FsDocGenRepository(AbstractDocGenRepository):
123+
def rollback(self):
124+
# TODO: This is not what rollback is for. We should be rolling back any
125+
# file changes
126+
self._doc_gen = None
127+
128+
def _get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
129+
self._doc_gen = DocGen.from_root(Path(doc_gen_root))
130+
self._doc_gen.collect_snippets()
131+
new_examples = self._get_new_examples()
132+
prompts = self._examples_to_prompts(new_examples)
133+
return prompts
134+
135+
def _get_new_examples(self) -> List[Tuple[str, Example]]:
136+
examples = self._doc_gen.examples
137+
138+
filtered_examples: List[Tuple[str, Example]] = []
139+
for example_id, example in examples.items():
140+
# TCXContentAnalyzer prefixes new metadata title/title_abbrev entries with
141+
# the DEFAULT_METADATA_PREFIX. Checking this here to make sure we're only
142+
# running the LLM tool on new extractions.
143+
title = example.title or ""
144+
title_abbrev = example.title_abbrev or ""
145+
if title.startswith(DEFAULT_METADATA_PREFIX) and title_abbrev.startswith(
146+
DEFAULT_METADATA_PREFIX
147+
):
148+
filtered_examples.append((example_id, example))
149+
return filtered_examples
150+
151+
def _examples_to_prompts(self, examples: List[Tuple[str, Example]]) -> List[Prompt]:
152+
snippets = self._doc_gen.snippets
153+
prompts = []
154+
for example_id, example in examples:
155+
key = (
156+
example.languages[IAM_POLICY_LANGUAGE]
157+
.versions[0]
158+
.excerpts[0]
159+
.snippet_files[0]
160+
.replace("/", ".")
161+
)
162+
snippet = snippets.get(key)
163+
prompts.append(Prompt(f"{example_id}.md", snippet.code))
164+
return prompts
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import abc
2+
3+
from aws_doc_sdk_examples_tools.lliam.adapters.repository import (
4+
AbstractPromptRepository,
5+
AbstractDocGenRepository,
6+
FsPromptRepository,
7+
FsDocGenRepository,
8+
)
9+
10+
11+
class AbstractUnitOfWork(abc.ABC):
12+
prompts: AbstractPromptRepository
13+
doc_gen: AbstractDocGenRepository
14+
15+
def __enter__(self) -> "AbstractUnitOfWork":
16+
return self
17+
18+
def __exit__(self, *args):
19+
self.rollback()
20+
21+
def commit(self):
22+
self._commit()
23+
24+
def collect_new_prompts(self):
25+
for prompt in self.prompts.seen:
26+
yield prompt
27+
28+
@abc.abstractmethod
29+
def _commit(self):
30+
raise NotImplementedError
31+
32+
@abc.abstractmethod
33+
def rollback(self):
34+
raise NotImplementedError
35+
36+
37+
class FsUnitOfWork(AbstractUnitOfWork):
38+
39+
def __enter__(self):
40+
self.prompts = FsPromptRepository()
41+
self.doc_gen = FsDocGenRepository()
42+
return super().__enter__()
43+
44+
def __exit__(self, *args):
45+
super().__exit__(*args)
46+
47+
def _commit(self):
48+
self.prompts.commit()
49+
50+
def rollback(self):
51+
self.prompts.rollback()
52+
self.doc_gen.rollback()

0 commit comments

Comments
 (0)