Skip to content

Commit d30b8d2

Browse files
committed
Add UoW and adapters to Lliam
1 parent 8aaa767 commit d30b8d2

File tree

6 files changed

+338
-0
lines changed

6 files changed

+338
-0
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import abc
2+
from itertools import islice
3+
from pathlib import Path
4+
from typing import Any, Generator, Iterable, List, Tuple
5+
6+
from aws_doc_sdk_examples_tools.doc_gen import DocGen, Example
7+
from aws_doc_sdk_examples_tools.fs import PathFs
8+
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt
9+
from aws_doc_sdk_examples_tools.lliam.shared_constants import BATCH_PREFIX
10+
11+
DEFAULT_METADATA_PREFIX = "DEFAULT"
12+
DEFAULT_BATCH_SIZE = 150
13+
IAM_POLICY_LANGUAGE = "IAMPolicyGrammar"
14+
15+
16+
def batched(iterable: Iterable, n: int) -> Generator[Tuple, Any, None]:
17+
"Batch data into tuples of length n. The last batch may be shorter."
18+
# batched('ABCDEFG', 3) --> ABC DEF G
19+
if n < 1:
20+
raise ValueError("n must be at least one")
21+
it = iter(iterable)
22+
while batch := tuple(islice(it, n)):
23+
yield batch
24+
25+
26+
class AbstractPromptRepository(abc.ABC):
27+
def __init__(self):
28+
self.to_write = {}
29+
30+
def add(self, prompt: Prompt):
31+
self._add(prompt)
32+
33+
def all_all(self, prompts: List[Prompt]):
34+
for prompt in prompts:
35+
self._add(prompt)
36+
37+
def batch(self, prompts: List[Prompt]):
38+
self._batch(prompts)
39+
40+
def commit(self):
41+
self._commit()
42+
43+
def get(self, id: str) -> Prompt:
44+
prompt = self._get(id)
45+
return prompt
46+
47+
def get_all(self, ids: List[str]) -> List[Prompt]:
48+
prompts = []
49+
for id in ids:
50+
prompt = self._get(id)
51+
prompts.append(prompt)
52+
return prompts
53+
54+
def set_partition(self, name: str):
55+
self.partition_name = name
56+
57+
@property
58+
def partition(self):
59+
return self.partition_name or ""
60+
61+
@abc.abstractmethod
62+
def _add(self, product: Prompt):
63+
raise NotImplementedError
64+
65+
@abc.abstractmethod
66+
def _batch(self, prompts: List[Prompt]):
67+
raise NotImplementedError
68+
69+
@abc.abstractmethod
70+
def _commit(self):
71+
raise NotImplementedError
72+
73+
@abc.abstractmethod
74+
def _get(self, id: str) -> Prompt:
75+
raise NotImplementedError
76+
77+
78+
class FsPromptRepository(AbstractPromptRepository):
79+
def __init__(self):
80+
super().__init__()
81+
self.fs = PathFs()
82+
83+
def rollback(self):
84+
# TODO: This is not what rollback is for. We should be rolling back any
85+
# file changes
86+
self.to_write = {}
87+
88+
def _add(self, prompt: Prompt):
89+
self.to_write[prompt.id] = prompt.content
90+
91+
def _batch(self, prompts: List[Prompt]):
92+
for batch_num, batch in enumerate(batched(prompts, DEFAULT_BATCH_SIZE)):
93+
batch_name = f"{BATCH_PREFIX}{(batch_num + 1):03}"
94+
for prompt in batch:
95+
prompt.id = f"{batch_name}/{prompt.id}"
96+
self._add(prompt)
97+
98+
def _commit(self):
99+
base_path = (
100+
Path(self.partition) if self.partition else Path(".")
101+
)
102+
103+
for file_path, content in self.to_write.items():
104+
if content:
105+
full_path = base_path / file_path
106+
self.fs.mkdir(full_path.parent)
107+
self.fs.write(full_path, content)
108+
109+
def _get(self, id: str):
110+
return Prompt(id, self.fs.read(Path(id)))
111+
112+
113+
class AbstractDocGenRepository(abc.ABC):
114+
def get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
115+
return self._get_new_prompts(doc_gen_root)
116+
117+
@abc.abstractmethod
118+
def _get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
119+
raise NotImplementedError
120+
121+
122+
class FsDocGenRepository(AbstractDocGenRepository):
123+
def rollback(self):
124+
# TODO: This is not what rollback is for. We should be rolling back any
125+
# file changes
126+
self._doc_gen = None
127+
128+
def _get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
129+
self._doc_gen = DocGen.from_root(Path(doc_gen_root))
130+
self._doc_gen.collect_snippets()
131+
new_examples = self._get_new_examples()
132+
prompts = self._examples_to_prompts(new_examples)
133+
return prompts
134+
135+
def _get_new_examples(self) -> List[Tuple[str, Example]]:
136+
examples = self._doc_gen.examples
137+
138+
filtered_examples: List[Tuple[str, Example]] = []
139+
for example_id, example in examples.items():
140+
# TCXContentAnalyzer prefixes new metadata title/title_abbrev entries with
141+
# the DEFAULT_METADATA_PREFIX. Checking this here to make sure we're only
142+
# running the LLM tool on new extractions.
143+
title = example.title or ""
144+
title_abbrev = example.title_abbrev or ""
145+
if title.startswith(DEFAULT_METADATA_PREFIX) and title_abbrev.startswith(
146+
DEFAULT_METADATA_PREFIX
147+
):
148+
filtered_examples.append((example_id, example))
149+
return filtered_examples
150+
151+
def _examples_to_prompts(self, examples: List[Tuple[str, Example]]) -> List[Prompt]:
152+
snippets = self._doc_gen.snippets
153+
prompts = []
154+
for example_id, example in examples:
155+
key = (
156+
example.languages[IAM_POLICY_LANGUAGE]
157+
.versions[0]
158+
.excerpts[0]
159+
.snippet_files[0]
160+
.replace("/", ".")
161+
)
162+
snippet = snippets.get(key)
163+
prompts.append(Prompt(f"{example_id}.md", snippet.code))
164+
return prompts

aws_doc_sdk_examples_tools/lliam/service_layer/__init__.py

Whitespace-only changes.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import logging
2+
3+
from aws_doc_sdk_examples_tools.lliam.domain.operations import build_ailly_config
4+
from aws_doc_sdk_examples_tools.lliam.domain.commands import CreatePrompts
5+
from aws_doc_sdk_examples_tools.lliam.service_layer.unit_of_work import (
6+
AbstractUnitOfWork,
7+
)
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def create_prompts(cmd: CreatePrompts, uow: AbstractUnitOfWork):
13+
with uow:
14+
system_prompts = uow.prompts.get_all(cmd.system_prompts)
15+
ailly_config = build_ailly_config(system_prompts)
16+
prompts = uow.doc_gen.get_new_prompts(cmd.doc_gen_root)
17+
uow.prompts.batch(prompts)
18+
uow.prompts.add(ailly_config)
19+
uow.prompts.set_partition(cmd.out_dir)
20+
uow.commit()
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import abc
2+
3+
from aws_doc_sdk_examples_tools.lliam.adapters.repository import (
4+
AbstractPromptRepository,
5+
AbstractDocGenRepository,
6+
FsPromptRepository,
7+
FsDocGenRepository,
8+
)
9+
10+
11+
class AbstractUnitOfWork(abc.ABC):
12+
prompts: AbstractPromptRepository
13+
doc_gen: AbstractDocGenRepository
14+
15+
def __enter__(self) -> "AbstractUnitOfWork":
16+
return self
17+
18+
def __exit__(self, *args):
19+
self.rollback()
20+
21+
def commit(self):
22+
self._commit()
23+
24+
def collect_new_prompts(self):
25+
for prompt in self.prompts.seen:
26+
yield prompt
27+
28+
@abc.abstractmethod
29+
def _commit(self):
30+
raise NotImplementedError
31+
32+
@abc.abstractmethod
33+
def rollback(self):
34+
raise NotImplementedError
35+
36+
37+
class FsUnitOfWork(AbstractUnitOfWork):
38+
39+
def __enter__(self):
40+
self.prompts = FsPromptRepository()
41+
self.doc_gen = FsDocGenRepository()
42+
return super().__enter__()
43+
44+
def __exit__(self, *args):
45+
super().__exit__(*args)
46+
47+
def _commit(self):
48+
self.prompts.commit()
49+
50+
def rollback(self):
51+
self.prompts.rollback()
52+
self.doc_gen.rollback()
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from pathlib import Path
2+
3+
AILLY_DIR = ".ailly_iam_policy"
4+
AILLY_DIR_PATH = Path(AILLY_DIR)
5+
BATCH_PREFIX = "batch_"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import pytest
2+
from typing import List
3+
4+
from aws_doc_sdk_examples_tools.lliam.domain.commands import CreatePrompts
5+
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt
6+
from aws_doc_sdk_examples_tools.lliam.service_layer.create_prompts import create_prompts
7+
from aws_doc_sdk_examples_tools.lliam.service_layer.unit_of_work import AbstractUnitOfWork
8+
from aws_doc_sdk_examples_tools.lliam.adapters.repository import (
9+
AbstractPromptRepository,
10+
AbstractDocGenRepository,
11+
)
12+
13+
14+
class FakePromptRepository(AbstractPromptRepository):
15+
def __init__(self):
16+
super().__init__()
17+
self.committed_prompts = {}
18+
self.system_prompts = {
19+
"system1.md": Prompt("system1.md", "System prompt 1 content"),
20+
"system2.md": Prompt("system2.md", "System prompt 2 content"),
21+
}
22+
self.partition_name = None
23+
24+
def _add(self, prompt: Prompt):
25+
self.to_write[prompt.id] = prompt.content
26+
27+
def _batch(self, prompts: List[Prompt]):
28+
for i, prompt in enumerate(prompts):
29+
batch_name = f"batch{(i // 150) + 1:03}"
30+
prompt.id = f"{batch_name}/{prompt.id}"
31+
self._add(prompt)
32+
33+
def _commit(self):
34+
self.committed_prompts.update(self.to_write)
35+
self.to_write = {}
36+
37+
def _get(self, id: str) -> Prompt:
38+
if id in self.system_prompts:
39+
return self.system_prompts[id]
40+
raise KeyError(f"Prompt {id} not found")
41+
42+
def rollback(self):
43+
self.to_write = {}
44+
45+
46+
class FakeDocGenRepository(AbstractDocGenRepository):
47+
def __init__(self):
48+
self.mock_prompts = [
49+
Prompt("example1.md", "Example 1 code content"),
50+
Prompt("example2.md", "Example 2 code content"),
51+
]
52+
53+
def _get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
54+
return self.mock_prompts
55+
56+
def rollback(self):
57+
pass
58+
59+
60+
class FakeUnitOfWork(AbstractUnitOfWork):
61+
def __init__(self):
62+
self.prompts = FakePromptRepository()
63+
self.doc_gen = FakeDocGenRepository()
64+
self.committed = False
65+
66+
def _commit(self):
67+
self.committed = True
68+
self.prompts.commit()
69+
70+
def rollback(self):
71+
self.committed = False
72+
self.prompts.rollback()
73+
self.doc_gen.rollback()
74+
75+
76+
def test_create_prompts_writes_when_commit_called():
77+
"""Test that create_prompts successfully writes prompts when commit is called."""
78+
uow = FakeUnitOfWork()
79+
cmd = CreatePrompts(
80+
doc_gen_root="/fake/doc_gen_root",
81+
system_prompts=["system1.md", "system2.md"],
82+
out_dir="/fake/output"
83+
)
84+
85+
create_prompts(cmd, uow)
86+
87+
# Ailly config should be in committed prompts
88+
assert ".aillyrc" in uow.prompts.committed_prompts
89+
ailly_config_content = uow.prompts.committed_prompts[".aillyrc"]
90+
assert "System prompt 1 content" in ailly_config_content
91+
assert "System prompt 2 content" in ailly_config_content
92+
93+
# New prompts from DocGen should be in committed prompts
94+
assert "batch001/example1.md" in uow.prompts.committed_prompts
95+
assert "batch001/example2.md" in uow.prompts.committed_prompts
96+
assert uow.prompts.committed_prompts["batch001/example1.md"] == "Example 1 code content"
97+
assert uow.prompts.committed_prompts["batch001/example2.md"] == "Example 2 code content"

0 commit comments

Comments
 (0)