Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/adapters/repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from itertools import islice
from math import ceil, floor, log
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Tuple

from aws_doc_sdk_examples_tools.doc_gen import DocGen, Example
from aws_doc_sdk_examples_tools.fs import Fs, PathFs
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt
from aws_doc_sdk_examples_tools.lliam.config import BATCH_PREFIX

DEFAULT_METADATA_PREFIX = "DEFAULT"
DEFAULT_BATCH_SIZE = 150
IAM_POLICY_LANGUAGE = "IAMPolicyGrammar"


def batched(iterable: Iterable, n: int) -> Generator[Tuple, Any, None]:
    """Yield successive tuples of at most *n* items from *iterable*.

    The final tuple may be shorter, e.g. batched('ABCDEFG', 3) --> ABC DEF G.

    Raises:
        ValueError: if *n* is less than one.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    while True:
        chunk = tuple(islice(iterator, n))
        if not chunk:
            return
        yield chunk


class PromptRepository:
    """Stages prompt content in memory and flushes it to the filesystem on commit."""

    def __init__(self, fs: Fs = PathFs()):
        self.fs = fs
        # Staged prompt content keyed by prompt id. Instance-level on purpose:
        # as a class attribute this dict was shared by every repository, so one
        # repository's staged writes leaked into all others.
        self.to_write: Dict[str, str] = {}
        # Initialized here so `commit()` works even when `set_partition` was
        # never called (the property previously raised AttributeError).
        self.partition_name = ""

    def rollback(self):
        # TODO: This is not what rollback is for. We should be rolling back any
        # file changes
        self.to_write = {}

    def add(self, prompt: Prompt):
        """Stage a single prompt for the next commit."""
        self.to_write[prompt.id] = prompt.content

    def add_all(self, prompts: Iterable[Prompt]):
        """Stage every prompt in *prompts* for the next commit."""
        for prompt in prompts:
            self.add(prompt)

    # Backward-compatible alias for the original (typo'd) method name.
    all_all = add_all

    def batch(self, prompts: Iterable[Prompt]):
        """Stage prompts grouped under zero-padded batch directories.

        Each prompt id is rewritten to "<BATCH_PREFIX><num>/<id>".
        """
        prompt_list = list(prompts)

        if not prompt_list:
            return

        batches_count = ceil(len(prompt_list) / DEFAULT_BATCH_SIZE)
        # Exact digit count; floor(log(n, 10)) + 1 is vulnerable to float
        # rounding at powers of ten.
        padding = len(str(batches_count))
        # Iterate the materialized list: iterating `prompts` again would yield
        # nothing when the caller passed a generator (already consumed above).
        for batch_num, batch in enumerate(
            batched(prompt_list, DEFAULT_BATCH_SIZE), start=1
        ):
            batch_name = f"{BATCH_PREFIX}{batch_num:0{padding}}"
            for prompt in batch:
                prompt.id = f"{batch_name}/{prompt.id}"
                self.add(prompt)

    def commit(self):
        """Write every staged prompt with non-empty content under the partition."""
        base_path = Path(self.partition) if self.partition else Path(".")

        for id, content in self.to_write.items():
            if content:
                full_path = base_path / id
                self.fs.mkdir(full_path.parent)
                self.fs.write(full_path, content)

    def get(self, id: str):
        """Read a single prompt's content from the filesystem."""
        return Prompt(id, self.fs.read(Path(id)))

    def get_all(self, ids: List[str]) -> List[Prompt]:
        """Read every prompt named in *ids* from the filesystem."""
        return [self.get(id) for id in ids]

    def set_partition(self, name: str):
        """Set the base directory that `commit` writes into."""
        self.partition_name = name

    @property
    def partition(self):
        return self.partition_name or ""


class DocGenRepository:
    """Reads examples out of a DocGen project and converts them into prompts."""

    def __init__(self, fs: Fs = PathFs()):
        self.fs = fs

    def rollback(self):
        # TODO: This is not what rollback is for. We should be rolling back any
        # file changes
        self._doc_gen = None

    def get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
        """Load DocGen from *doc_gen_root* and return prompts for new examples."""
        # Right now this is the only instance of DocGen used in this Repository,
        # but if that changes we need to move it up.
        self._doc_gen = DocGen.from_root(Path(doc_gen_root), fs=self.fs)
        self._doc_gen.collect_snippets()
        return self._examples_to_prompts(self._get_new_examples())

    def _get_new_examples(self) -> List[Tuple[str, Example]]:
        """Return (id, example) pairs whose metadata marks them as new extractions."""

        # TCXContentAnalyzer prefixes new metadata title/title_abbrev entries with
        # the DEFAULT_METADATA_PREFIX. Checking this here to make sure we're only
        # running the LLM tool on new extractions.
        def _is_new(example: Example) -> bool:
            title = example.title or ""
            abbrev = example.title_abbrev or ""
            return title.startswith(DEFAULT_METADATA_PREFIX) and abbrev.startswith(
                DEFAULT_METADATA_PREFIX
            )

        return [
            (example_id, example)
            for example_id, example in self._doc_gen.examples.items()
            if _is_new(example)
        ]

    def _examples_to_prompts(self, examples: List[Tuple[str, Example]]) -> List[Prompt]:
        """Build one markdown prompt per example from its first snippet file."""
        all_snippets = self._doc_gen.snippets
        prompts: List[Prompt] = []
        for example_id, example in examples:
            language = example.languages[IAM_POLICY_LANGUAGE]
            snippet_file = language.versions[0].excerpts[0].snippet_files[0]
            # Snippet keys use dots where the file path uses slashes.
            snippet = all_snippets.get(snippet_file.replace("/", "."))
            # NOTE(review): `snippet` is None when the key is missing, which
            # would raise AttributeError below — confirm keys are always present.
            prompts.append(Prompt(f"{example_id}.md", snippet.code))
        return prompts
5 changes: 5 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pathlib import Path

# Directory name that holds generated Ailly prompt files for IAM policies.
AILLY_DIR = ".ailly_iam_policy"
# Same directory as a Path object, for callers doing path arithmetic.
AILLY_DIR_PATH = Path(AILLY_DIR)
# Prefix for batch subdirectory names produced by PromptRepository.batch
# (e.g. "batch_1").
BATCH_PREFIX = "batch_"
2 changes: 1 addition & 1 deletion aws_doc_sdk_examples_tools/lliam/domain/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Command:

@dataclass
class CreatePrompts(Command):
    """Command to generate prompt files from a DocGen root.

    Attributes:
        doc_gen_root: Root directory of the DocGen project to scan.
        system_prompts: Paths of system-prompt files to fold into the config.
        out_dir: Directory the generated prompts are written to.
    """

    # Declared exactly once; the source carried a duplicated
    # `doc_gen_root: str` line (diff-extraction artifact).
    doc_gen_root: str
    system_prompts: List[str]
    out_dir: str

Expand Down
Empty file.
18 changes: 18 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/service_layer/create_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import logging

from aws_doc_sdk_examples_tools.lliam.domain.operations import build_ailly_config
from aws_doc_sdk_examples_tools.lliam.domain.commands import CreatePrompts
from aws_doc_sdk_examples_tools.lliam.service_layer.unit_of_work import FsUnitOfWork

logger = logging.getLogger(__name__)


def create_prompts(cmd: CreatePrompts, uow: FsUnitOfWork):
    """Build the Ailly config plus per-example prompts and persist them.

    System prompts are read first, combined into the Ailly config, then new
    DocGen examples are batched, and everything is committed under cmd.out_dir.
    """
    with uow:
        ailly_config = build_ailly_config(uow.prompts.get_all(cmd.system_prompts))
        new_prompts = uow.doc_gen.get_new_prompts(cmd.doc_gen_root)
        uow.prompts.batch(new_prompts)
        uow.prompts.add(ailly_config)
        uow.prompts.set_partition(cmd.out_dir)
        uow.commit()
24 changes: 24 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/service_layer/unit_of_work.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from aws_doc_sdk_examples_tools.fs import Fs, PathFs
from aws_doc_sdk_examples_tools.lliam.adapters.repository import (
PromptRepository,
DocGenRepository,
)


class FsUnitOfWork:
    """Unit of work scoping prompt and DocGen repositories to one filesystem."""

    def __init__(self, fs: Fs = PathFs()):
        self.fs = fs

    def __enter__(self):
        self.prompts = PromptRepository(fs=self.fs)
        self.doc_gen = DocGenRepository(fs=self.fs)
        # Return self so `with FsUnitOfWork() as uow:` binds the unit of work;
        # previously this returned None, silently binding None in the `as` form.
        return self

    def __exit__(self, *args):
        # Always roll back on exit: after a successful commit the staged state
        # has been flushed, so this only clears in-memory leftovers.
        self.rollback()

    def commit(self):
        """Flush all staged prompt writes to the filesystem."""
        self.prompts.commit()

    def rollback(self):
        """Discard staged, uncommitted state in both repositories."""
        self.prompts.rollback()
        self.doc_gen.rollback()
31 changes: 31 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/test/repository_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path

from aws_doc_sdk_examples_tools.fs import RecordFs
from aws_doc_sdk_examples_tools.lliam.adapters.repository import PromptRepository
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt


def test_batch_naming_occurs_properly():
    """Test that batch naming occurs properly when batching prompts."""
    repo = PromptRepository(fs=RecordFs({}))

    prompts = [Prompt(f"prompt_{i}.md", f"Content for prompt {i}") for i in range(300)]
    repo.batch(prompts)

    # 300 prompts at a batch size of 150 should yield exactly two full batches.
    counts = {"batch_1/": 0, "batch_2/": 0}
    for prompt_id in repo.to_write:
        for prefix in counts:
            if prompt_id.startswith(prefix):
                counts[prefix] += 1

    assert counts["batch_1/"] == 150
    assert counts["batch_2/"] == 150
32 changes: 32 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/test/services_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from pathlib import Path

from aws_doc_sdk_examples_tools.fs import RecordFs
from aws_doc_sdk_examples_tools.lliam.domain.commands import CreatePrompts
from aws_doc_sdk_examples_tools.lliam.service_layer.create_prompts import create_prompts
from aws_doc_sdk_examples_tools.lliam.service_layer.unit_of_work import FsUnitOfWork


def test_create_prompts_writes_when_commit_called():
    """Test that create_prompts successfully writes prompts when commit is called."""
    fake_fs = RecordFs(
        {
            Path("/system1.md"): "System prompt 1 content",
            Path("/system2.md"): "System prompt 2 content",
            Path("/fake/doc_gen_root"): "",
        }
    )
    command = CreatePrompts(
        doc_gen_root="/fake/doc_gen_root",
        system_prompts=["/system1.md", "/system2.md"],
        out_dir="/fake/output",
    )

    create_prompts(command, FsUnitOfWork(fs=fake_fs))

    # Ailly config should be in committed prompts
    config_path = Path("/fake/output/.aillyrc")
    assert fake_fs.stat(config_path).exists
    written_config = fake_fs.read(config_path)
    for expected in ("System prompt 1 content", "System prompt 2 content"):
        assert expected in written_config