Skip to content

Commit d5936cb

Browse files
committed
pattern class init
1 parent f92a286 commit d5936cb

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

mellea/stdlib/pattern.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
2+
from typing import Any
3+
from dataclasses import dataclass
4+
from string import Template
5+
from pydantic import BaseModel, create_model
6+
7+
class Pattern:
8+
9+
def __init__(self, pattern:str, **types:type[Any]):
10+
# e.g.,
11+
self.pattern = Template(pattern)
12+
self.types = types
13+
14+
15+
def defaulted_types(self) -> dict[str, type[Any]]:
16+
# defaults the variable type to str.
17+
# we do not do this defaulting in the constructor because
18+
# composition may add in additional type information.
19+
# in other words, type defaults are performed in the very last step.
20+
21+
return {
22+
v : (
23+
self.types[v]
24+
if v in self.types
25+
else str
26+
)
27+
for variable in self.pattern.get_identifiers()
28+
}
29+
30+
31+
def model(self) -> type[BaseModel]:
32+
33+
return create_model("TemplateModel", **self.defaulted_types()) # type: ignore
34+
35+
36+
def format(self) -> str:
37+
38+
return (
39+
"Fill in the blank in the following template, " +
40+
"where the placeholders are denoted as $var for a variable named 'var'. " +
41+
"Answer in a json schema shown after the template. " +
42+
"\nTemplate:\n" +
43+
self.pattern.template
44+
"\nSchema:\n" +
45+
self.model().dump_json_schema()
46+
)
47+
48+
49+
if __name__ == "__main__":
50+
51+
from mellea import start_session, SimpleContext
52+
m = start_session()
53+
p = Pattern("the height of $mountain is $height meters.",
54+
mountain=str,
55+
height=int)
56+
57+
print(m.instruct(p.format(),
58+
format=p.model()))
59+

0 commit comments

Comments
 (0)