125 changes: 125 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/adapters/repository.py
@@ -0,0 +1,125 @@
import abc
from itertools import islice
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Tuple

from aws_doc_sdk_examples_tools.doc_gen import DocGen, Example
from aws_doc_sdk_examples_tools.fs import Fs, PathFs
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt
from aws_doc_sdk_examples_tools.lliam.shared_constants import BATCH_PREFIX

DEFAULT_METADATA_PREFIX = "DEFAULT"
DEFAULT_BATCH_SIZE = 150
IAM_POLICY_LANGUAGE = "IAMPolicyGrammar"


def batched(iterable: Iterable, n: int) -> Generator[Tuple, Any, None]:
    """Batch data into tuples of length n. The last batch may be shorter."""
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


class FsPromptRepository:
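    """Filesystem-backed prompt repository.

    Prompts are staged in memory via add(), add_all(), and batch(), and are only
    written to disk (under the configured partition directory) when commit() runs.
    """
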
    def __init__(self, fs: Fs = PathFs()):
        self.fs = fs
        # Stage content per instance; a class-level dict would be shared across
        # every repository instance.
        self.to_write: Dict[str, str] = {}
        # Prompts added during this unit of work, used by collect_new_prompts().
        self.seen: List[Prompt] = []
        self.partition_name = ""

    def rollback(self):
        # TODO: This is not what rollback is for. We should be rolling back any
        # file changes
        self.to_write = {}

    def add(self, prompt: Prompt):
        self.seen.append(prompt)
        self.to_write[prompt.id] = prompt.content

    def add_all(self, prompts: List[Prompt]):
        for prompt in prompts:
            self.add(prompt)

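    # Stage prompts in batches of DEFAULT_BATCH_SIZE, rewriting each prompt id
    # to live under its batch directory (e.g. "batch_001/<id>").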
    def batch(self, prompts: List[Prompt]):
        for batch_num, batch in enumerate(batched(prompts, DEFAULT_BATCH_SIZE)):
            batch_name = f"{BATCH_PREFIX}{(batch_num + 1):03}"
            for prompt in batch:
                prompt.id = f"{batch_name}/{prompt.id}"
                self.add(prompt)

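    # Write every staged file beneath the partition directory (or the current
    # directory when no partition is set), creating parent directories as
    # needed and skipping empty content.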
    def commit(self):
        base_path = Path(self.partition) if self.partition else Path(".")

        for file_path, content in self.to_write.items():
            if content:
                full_path = base_path / file_path
                self.fs.mkdir(full_path.parent)
                self.fs.write(full_path, content)

    def get(self, id: str):
        return Prompt(id, self.fs.read(Path(id)))

    def get_all(self, ids: List[str]) -> List[Prompt]:
        prompts = []
        for id in ids:
            prompt = self.get(id)
            prompts.append(prompt)
        return prompts

    def set_partition(self, name: str):
        self.partition_name = name

    @property
    def partition(self):
        return self.partition_name or ""


class FsDocGenRepository:
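    """Load DocGen metadata from a docs repository root and convert newly
    extracted examples into Prompt objects."""
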
    def __init__(self, fs: Fs = PathFs()):
        self.fs = fs

    def rollback(self):
        # TODO: This is not what rollback is for. We should be rolling back any
        # file changes
        self._doc_gen = None

    def get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
        # Right now this is the only instance of DocGen used in this Repository,
        # but if that changes we need to move it up.
        self._doc_gen = DocGen.from_root(Path(doc_gen_root), fs=self.fs)
        self._doc_gen.collect_snippets()
        new_examples = self._get_new_examples()
        prompts = self._examples_to_prompts(new_examples)
        return prompts

    def _get_new_examples(self) -> List[Tuple[str, Example]]:
        examples = self._doc_gen.examples

        filtered_examples: List[Tuple[str, Example]] = []
        for example_id, example in examples.items():
            # TCXContentAnalyzer prefixes new metadata title/title_abbrev entries with
            # the DEFAULT_METADATA_PREFIX. Checking this here to make sure we're only
            # running the LLM tool on new extractions.
            title = example.title or ""
            title_abbrev = example.title_abbrev or ""
            if title.startswith(DEFAULT_METADATA_PREFIX) and title_abbrev.startswith(
                DEFAULT_METADATA_PREFIX
            ):
                filtered_examples.append((example_id, example))
        return filtered_examples

    def _examples_to_prompts(self, examples: List[Tuple[str, Example]]) -> List[Prompt]:
        snippets = self._doc_gen.snippets
        prompts = []
        for example_id, example in examples:
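            # Each new example is expected to carry a single IAM policy excerpt;
            # its first snippet file path ("/" replaced by ".") is the key into
            # the collected snippets.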
            key = (
                example.languages[IAM_POLICY_LANGUAGE]
                .versions[0]
                .excerpts[0]
                .snippet_files[0]
                .replace("/", ".")
            )
            snippet = snippets.get(key)
            prompts.append(Prompt(f"{example_id}.md", snippet.code))
        return prompts
2 changes: 1 addition & 1 deletion aws_doc_sdk_examples_tools/lliam/domain/commands.py
@@ -9,7 +9,7 @@ class Command:

@dataclass
class CreatePrompts(Command):
    doc_gen_root: str
    system_prompts: List[str]
    out_dir: str

Empty file.
18 changes: 18 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/service_layer/create_prompts.py
@@ -0,0 +1,18 @@
import logging

from aws_doc_sdk_examples_tools.lliam.domain.operations import build_ailly_config
from aws_doc_sdk_examples_tools.lliam.domain.commands import CreatePrompts
from aws_doc_sdk_examples_tools.lliam.service_layer.unit_of_work import FsUnitOfWork

logger = logging.getLogger(__name__)


def create_prompts(cmd: CreatePrompts, uow: FsUnitOfWork):
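    """Build the Ailly config from the given system prompts, gather prompts for
    newly extracted DocGen examples, and stage everything for writing under
    cmd.out_dir."""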
    with uow:
        system_prompts = uow.prompts.get_all(cmd.system_prompts)
        ailly_config = build_ailly_config(system_prompts)
        prompts = uow.doc_gen.get_new_prompts(cmd.doc_gen_root)
        uow.prompts.batch(prompts)
        uow.prompts.add(ailly_config)
        uow.prompts.set_partition(cmd.out_dir)
        uow.commit()
28 changes: 28 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/service_layer/unit_of_work.py
@@ -0,0 +1,28 @@
from aws_doc_sdk_examples_tools.fs import Fs
from aws_doc_sdk_examples_tools.lliam.adapters.repository import (
    FsPromptRepository,
    FsDocGenRepository,
)


class FsUnitOfWork:
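    """Group the prompt and DocGen repositories behind one context manager that
    shares a single filesystem; staged prompts are written on commit() and
    discarded by rollback() when the context exits."""
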
    def __init__(self, fs: Fs):
        self.fs = fs

    def __enter__(self):
        self.prompts = FsPromptRepository(fs=self.fs)
        self.doc_gen = FsDocGenRepository(fs=self.fs)
        return self
    def __exit__(self, *args):
        self.rollback()

    def commit(self):
        self.prompts.commit()

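    # Yield the prompts that were added to the prompt repository during this
    # unit of work.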
    def collect_new_prompts(self):
        for prompt in self.prompts.seen:
            yield prompt

    def rollback(self):
        self.prompts.rollback()
        self.doc_gen.rollback()
5 changes: 5 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/shared_constants.py
@@ -0,0 +1,5 @@
from pathlib import Path

AILLY_DIR = ".ailly_iam_policy"
AILLY_DIR_PATH = Path(AILLY_DIR)
BATCH_PREFIX = "batch_"
32 changes: 32 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/test/services_test.py
@@ -0,0 +1,32 @@
from pathlib import Path

from aws_doc_sdk_examples_tools.fs import RecordFs
from aws_doc_sdk_examples_tools.lliam.domain.commands import CreatePrompts
from aws_doc_sdk_examples_tools.lliam.service_layer.create_prompts import create_prompts
from aws_doc_sdk_examples_tools.lliam.service_layer.unit_of_work import FsUnitOfWork


def test_create_prompts_writes_when_commit_called():
    """Test that create_prompts successfully writes prompts when commit is called."""
    fs = RecordFs(
        {
            Path("system1.md"): "System prompt 1 content",
            Path("system2.md"): "System prompt 2 content",
            Path("fake/doc_gen_root"): "",
        }
    )
    uow = FsUnitOfWork(fs=fs)
    cmd = CreatePrompts(
        doc_gen_root="fake/doc_gen_root",
        system_prompts=["system1.md", "system2.md"],
        out_dir="/fake/output",
    )

    create_prompts(cmd, uow)

    # Ailly config should be in committed prompts
    ailly_config_path = Path("/fake/output/.aillyrc")
    assert fs.stat(ailly_config_path).exists
    ailly_config_content = fs.read(ailly_config_path)
    assert "System prompt 1 content" in ailly_config_content
    assert "System prompt 2 content" in ailly_config_content