Skip to content

Commit 50a3fac

Browse files
committed
Make batch num padding dynamic, and clean up a few things.
1 parent d40e692 commit 50a3fac

File tree

4 files changed

+25
-22
lines changed

4 files changed

+25
-22
lines changed

aws_doc_sdk_examples_tools/lliam/adapters/repository.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import abc
21
from itertools import islice
2+
from math import ceil, floor, log
33
from pathlib import Path
44
from typing import Any, Dict, Generator, Iterable, List, Tuple
55

66
from aws_doc_sdk_examples_tools.doc_gen import DocGen, Example
77
from aws_doc_sdk_examples_tools.fs import Fs, PathFs
88
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt
9-
from aws_doc_sdk_examples_tools.lliam.shared_constants import BATCH_PREFIX
9+
from aws_doc_sdk_examples_tools.lliam.config import BATCH_PREFIX
1010

1111
DEFAULT_METADATA_PREFIX = "DEFAULT"
1212
DEFAULT_BATCH_SIZE = 150
@@ -23,7 +23,7 @@ def batched(iterable: Iterable, n: int) -> Generator[Tuple, Any, None]:
2323
yield batch
2424

2525

26-
class FsPromptRepository:
26+
class PromptRepository:
2727
to_write: Dict[str, str] = {}
2828

2929
def __init__(self, fs: Fs = PathFs()):
@@ -37,23 +37,30 @@ def rollback(self):
3737
def add(self, prompt: Prompt):
3838
self.to_write[prompt.id] = prompt.content
3939

40-
def all_all(self, prompts: List[Prompt]):
40+
def all_all(self, prompts: Iterable[Prompt]):
4141
for prompt in prompts:
4242
self.add(prompt)
4343

44-
def batch(self, prompts: List[Prompt]):
44+
def batch(self, prompts: Iterable[Prompt]):
45+
prompt_list = list(prompts)
46+
47+
if not prompt_list:
48+
return
49+
50+
batches_count = ceil(len(prompt_list) / DEFAULT_BATCH_SIZE)
51+
padding = floor(log(batches_count, 10)) + 1
4552
for batch_num, batch in enumerate(batched(prompts, DEFAULT_BATCH_SIZE)):
46-
batch_name = f"{BATCH_PREFIX}{(batch_num + 1):03}"
53+
batch_name = f"{BATCH_PREFIX}{(batch_num + 1):0{padding}}"
4754
for prompt in batch:
4855
prompt.id = f"{batch_name}/{prompt.id}"
4956
self.add(prompt)
5057

5158
def commit(self):
5259
base_path = Path(self.partition) if self.partition else Path(".")
5360

54-
for file_path, content in self.to_write.items():
61+
for id, content in self.to_write.items():
5562
if content:
56-
full_path = base_path / file_path
63+
full_path = base_path / id
5764
self.fs.mkdir(full_path.parent)
5865
self.fs.write(full_path, content)
5966

@@ -75,7 +82,7 @@ def partition(self):
7582
return self.partition_name or ""
7683

7784

78-
class FsDocGenRepository:
85+
class DocGenRepository:
7986
def __init__(self, fs: Fs = PathFs()):
8087
self.fs = fs
8188

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from aws_doc_sdk_examples_tools.fs import Fs
22
from aws_doc_sdk_examples_tools.lliam.adapters.repository import (
3-
FsPromptRepository,
4-
FsDocGenRepository,
3+
PromptRepository,
4+
DocGenRepository,
55
)
66

77

@@ -10,19 +10,15 @@ def __init__(self, fs: Fs):
1010
self.fs = fs
1111

1212
def __enter__(self):
13-
self.prompts = FsPromptRepository(fs=self.fs)
14-
self.doc_gen = FsDocGenRepository(fs=self.fs)
13+
self.prompts = PromptRepository(fs=self.fs)
14+
self.doc_gen = DocGenRepository(fs=self.fs)
1515

1616
def __exit__(self, *args):
1717
self.rollback()
1818

1919
def commit(self):
2020
self.prompts.commit()
2121

22-
def collect_new_prompts(self):
23-
for prompt in self.prompts.seen:
24-
yield prompt
25-
2622
def rollback(self):
2723
self.prompts.rollback()
2824
self.doc_gen.rollback()

aws_doc_sdk_examples_tools/lliam/test/services_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ def test_create_prompts_writes_when_commit_called():
1010
"""Test that create_prompts successfully writes prompts when commit is called."""
1111
fs = RecordFs(
1212
{
13-
Path("system1.md"): "System prompt 1 content",
14-
Path("system2.md"): "System prompt 2 content",
15-
Path("fake/doc_gen_root"): "",
13+
Path("/system1.md"): "System prompt 1 content",
14+
Path("/system2.md"): "System prompt 2 content",
15+
Path("/fake/doc_gen_root"): "",
1616
}
1717
)
1818
uow = FsUnitOfWork(fs=fs)
1919
cmd = CreatePrompts(
20-
doc_gen_root="fake/doc_gen_root",
21-
system_prompts=["system1.md", "system2.md"],
20+
doc_gen_root="/fake/doc_gen_root",
21+
system_prompts=["/system1.md", "/system2.md"],
2222
out_dir="/fake/output",
2323
)
2424

0 commit comments

Comments
 (0)