Skip to content

Commit 944f1fe

Browse files
authored
feat: python: implement spec test logic (#285)
* feat: python: implement spec test logic ISSUE: #66 CHANGELOG: [x] fixed yaml deserialization [x] implemented a test body * feat: python: implement spec test logic ISSUE: #66 CHANGELOG: [x] fixed yaml deserialization [x] implemented a test body * feat: python: implement spec test logic ISSUE: #66 CHANGELOG: [x] fixed yaml deserialization [x] implemented a test body
1 parent 687fbbd commit 944f1fe

File tree

2 files changed

+78
-64
lines changed

2 files changed

+78
-64
lines changed

python/dotpromptz/src/dotpromptz/dotprompt.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,7 @@ async def __call__(
152152
Returns:
153153
The rendered prompt.
154154
"""
155-
# Discard the input schema as once rendered it doesn't make sense.
156155
merged_metadata: PromptMetadata[ModelConfigT] = await self._dotprompt.render_metadata(self.prompt, options)
157-
merged_metadata.input = None
158156

159157
# Prepare input data, merging defaults from options if available.
160158
context: Context = {
@@ -381,7 +379,7 @@ async def _resolve_metadata(
381379
out = base.model_copy(deep=True)
382380

383381
for merge in merges:
384-
if merge is not None:
382+
if merge:
385383
out = _merge_metadata(out, merge)
386384

387385
# Remove the template attribute if it exists (TS does this).

python/dotpromptz/tests/dotpromptz/spec_test.py

Lines changed: 77 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,27 @@
9999

100100
from __future__ import annotations
101101

102+
import asyncio
102103
import re
103104
import unittest
104105
from collections.abc import Callable, Coroutine
105106
from pathlib import Path
106-
from typing import Any, TypedDict
107+
from typing import Any, Generic, TypedDict
107108

108109
import structlog
109110
import yaml
111+
from pydantic import BaseModel, Field
110112

111113
from dotpromptz.dotprompt import Dotprompt
112-
from dotpromptz.typing import DataArgument, JsonSchema, ToolDefinition
114+
from dotpromptz.typing import (
115+
DataArgument,
116+
JsonSchema,
117+
Message,
118+
ModelConfigT,
119+
PromptInputConfig,
120+
PromptMetadata,
121+
ToolDefinition,
122+
)
113123

114124
logger = structlog.get_logger(__name__)
115125

@@ -121,16 +131,18 @@
121131
# List of files that are allowed to be used as spec files.
122132
# Useful for debugging and testing.
123133
ALLOWLISTED_FILES = [
134+
# TODO(#284): most of commented out tests are failing because of issues with
135+
# the handelbarz implementation.
124136
'spec/helpers/history.yaml',
125-
'spec/helpers/ifEquals.yaml',
126-
'spec/helpers/json.yaml',
137+
# 'spec/helpers/ifEquals.yaml',
138+
# 'spec/helpers/json.yaml',
127139
'spec/helpers/media.yaml',
128140
'spec/helpers/role.yaml',
129-
'spec/helpers/section.yaml',
130-
'spec/helpers/unlessEquals.yaml',
131-
'spec/metadata.yaml',
132-
'spec/partials.yaml',
133-
'spec/picoschema.yaml',
141+
# 'spec/helpers/section.yaml',
142+
# 'spec/helpers/unlessEquals.yaml',
143+
# 'spec/metadata.yaml',
144+
# 'spec/partials.yaml',
145+
# 'spec/picoschema.yaml',
134146
'spec/variables.yaml',
135147
]
136148

@@ -139,37 +151,37 @@
139151
test_case_counter = 0
140152

141153

142-
class Expect(TypedDict, total=False):
154+
class Expect(BaseModel):
143155
"""An expectation for the spec."""
144156

145-
config: bool
146-
ext: bool
147-
input: bool
148-
messages: bool
149-
metadata: bool
150-
raw: bool
157+
config: dict[Any, Any] = Field(default_factory=dict)
158+
ext: dict[str, dict[str, Any]] = Field(default_factory=dict)
159+
input: PromptInputConfig | None = None
160+
messages: list[Message] = Field(default_factory=list)
161+
metadata: dict[str, Any] = Field(default_factory=dict)
162+
raw: dict[str, Any] | None = None
151163

152164

153-
class SpecTest(TypedDict, total=False):
165+
class SpecTest(BaseModel, Generic[ModelConfigT]):
154166
"""A test case for a YAML spec."""
155167

156-
desc: str
157-
data: DataArgument[Any]
168+
desc: str = Field(default='UnnamedTest')
169+
data: DataArgument[Any] | None = None
158170
expect: Expect
159-
options: dict[str, Any]
171+
options: PromptMetadata[ModelConfigT] | None = None
160172

161173

162-
class SpecSuite(TypedDict, total=False):
174+
class SpecSuite(BaseModel, Generic[ModelConfigT]):
163175
"""A suite of test cases for a YAML spec."""
164176

165-
name: str
177+
name: str = Field(default='UnnamedSuite')
166178
template: str
167-
data: DataArgument[Any]
168-
schemas: dict[str, JsonSchema]
169-
tools: dict[str, ToolDefinition]
170-
partials: dict[str, str]
171-
resolver_partials: dict[str, str]
172-
tests: list[SpecTest]
179+
data: DataArgument[Any] | None = None
180+
schemas: dict[str, JsonSchema] | None = None
181+
tools: dict[str, ToolDefinition] | None = None
182+
partials: dict[str, str] = Field(default_factory=dict)
183+
resolver_partials: dict[str, str] = Field(default_factory=dict)
184+
tests: list[SpecTest[ModelConfigT]] = Field(default_factory=list)
173185

174186

175187
def is_allowed_spec_file(file: Path) -> bool:
@@ -236,7 +248,7 @@ def make_test_class_name(yaml_file_name: str, suite_name: str | None) -> str:
236248
return f'Test_{file_part}_{suite_part}Suite'
237249

238250

239-
def make_dotprompt_for_suite(suite: SpecSuite) -> Dotprompt:
251+
def make_dotprompt_for_suite(suite: SpecSuite[ModelConfigT]) -> Dotprompt:
240252
"""Constructs and sets up a Dotprompt instance for the given suite.
241253
242254
Args:
@@ -245,19 +257,19 @@ def make_dotprompt_for_suite(suite: SpecSuite) -> Dotprompt:
245257
Returns:
246258
A Dotprompt instance.
247259
"""
248-
resolver_partials_from_suite: dict[str, str] = suite.get('resolver_partials', {})
260+
resolver_partials_from_suite: dict[str, str] = suite.resolver_partials
249261

250262
def partial_resolver_fn(name: str) -> str | None:
251263
return resolver_partials_from_suite.get(name)
252264

253265
dotprompt = Dotprompt(
254-
schemas=suite.get('schemas'),
255-
tools=suite.get('tools'),
256-
partial_resolver=partial_resolver_fn if resolver_partials_from_suite else None,
266+
schemas=suite.schemas,
267+
tools=suite.tools,
268+
partial_resolver=partial_resolver_fn if suite.resolver_partials else None,
257269
)
258270

259271
# Register partials directly defined in the suite
260-
defined_partials: dict[str, str] = suite.get('partials', {})
272+
defined_partials: dict[str, str] = suite.partials
261273
for name, template_content in defined_partials.items():
262274
dotprompt.define_partial(name, template_content)
263275

@@ -284,10 +296,12 @@ def test_spec_files_are_valid(self) -> None:
284296
self.assertIsNotNone(data)
285297

286298

287-
class YamlSpecTestBase(unittest.IsolatedAsyncioTestCase):
299+
class YamlSpecTestBase(unittest.IsolatedAsyncioTestCase, Generic[ModelConfigT]):
288300
"""A base class that is used as a template for all YAML spec test suites."""
289301

290-
async def run_yaml_test(self, yaml_file: Path, suite: SpecSuite, test_case: SpecTest) -> None:
302+
async def run_yaml_test(
303+
self, yaml_file: Path, suite: SpecSuite[ModelConfigT], test_case: SpecTest[ModelConfigT]
304+
) -> None:
291305
"""Runs a YAML test.
292306
293307
Args:
@@ -298,15 +312,24 @@ async def run_yaml_test(self, yaml_file: Path, suite: SpecSuite, test_case: Spec
298312
Returns:
299313
None.
300314
"""
301-
suite_name = suite.get('name', 'UnnamedSuite')
302-
test_desc = test_case.get('desc', 'UnnamedTest')
303-
logger.info(f'[TEST] {yaml_file.stem} > {suite_name} > {test_desc}')
315+
logger.info(f'[TEST] {yaml_file.stem} > {suite.name} > {test_case.desc}')
304316

305317
# Create test-specific dotprompt instance.
306318
dotprompt = make_dotprompt_for_suite(suite)
307319
self.assertIsNotNone(dotprompt)
308320

309-
# TODO: Add test logic here.
321+
data = self._merge_data(suite.data or DataArgument[Any](), test_case.data or DataArgument[Any]())
322+
result = await dotprompt.render(suite.template, data, test_case.options)
323+
pruned_res: Expect = Expect(**result.model_dump())
324+
self.assertEqual(pruned_res, test_case.expect)
325+
326+
def _merge_data(self, data1: DataArgument[Any], data2: DataArgument[Any]) -> DataArgument[Any]:
327+
merged = DataArgument[Any]()
328+
merged.input = data1.input or data2.input
329+
merged.docs = (data1.docs or []) + (data2.docs or [])
330+
merged.messages = (data1.messages or []) + (data2.messages or [])
331+
merged.context = {**(data1.context or {}), **(data1.context or {})}
332+
return merged
310333

311334

312335
def make_suite_class_name(yaml_file: Path, suite_name: str | None) -> str:
@@ -347,9 +370,9 @@ def make_test_case_name(yaml_file: Path, suite_name: str, test_desc: str) -> str
347370

348371
def make_async_test_case_method(
349372
yaml_file: Path,
350-
suite: SpecSuite,
351-
test_case: SpecTest,
352-
) -> Callable[[YamlSpecTestBase], Coroutine[Any, Any, None]]:
373+
suite: SpecSuite[ModelConfigT],
374+
test_case: SpecTest[ModelConfigT],
375+
) -> Callable[[YamlSpecTestBase[ModelConfigT]], Coroutine[Any, Any, None]]:
353376
"""Creates an async test method for a test case.
354377
355378
Args:
@@ -361,7 +384,7 @@ def make_async_test_case_method(
361384
An async test method.
362385
"""
363386

364-
async def test_method(self_dynamic: YamlSpecTestBase) -> None:
387+
async def test_method(self_dynamic: YamlSpecTestBase[ModelConfigT]) -> None:
365388
"""An async test method."""
366389
await self_dynamic.run_yaml_test(yaml_file, suite, test_case)
367390

@@ -370,7 +393,7 @@ async def test_method(self_dynamic: YamlSpecTestBase) -> None:
370393

371394
def make_async_skip_test_method(
372395
yaml_file: Path, suite_name: str
373-
) -> Callable[[YamlSpecTestBase], Coroutine[Any, Any, None]]:
396+
) -> Callable[[YamlSpecTestBase[ModelConfigT]], Coroutine[Any, Any, None]]:
374397
"""Creates a skip test for a suite.
375398
376399
Args:
@@ -381,7 +404,7 @@ def make_async_skip_test_method(
381404
A skip test.
382405
"""
383406

384-
async def skip_method(self_dynamic: YamlSpecTestBase) -> None:
407+
async def skip_method(self_dynamic: YamlSpecTestBase[ModelConfigT]) -> None:
385408
self_dynamic.skipTest(f"Suite '{suite_name}' in {yaml_file.stem} has no tests.")
386409

387410
return skip_method
@@ -417,29 +440,22 @@ def generate_test_suites(files: list[Path]) -> None:
417440
# Iterate over the suites in the YAML file and ensure it has a name.
418441
for suite_data in suites_data:
419442
# Normalize the suite data to ensure it has a name.
420-
suite: SpecSuite = suite_data
421-
suite_name = suite.get('name', f'UnnamedSuite_{yaml_file.stem}')
422-
suite['name'] = suite_name
443+
suite = SpecSuite(**suite_data)
444+
suite.name = suite.name or f'UnnamedSuite_{yaml_file.stem}'
423445

424446
# Create the dynamic test class for the suite.
425-
class_name = make_suite_class_name(yaml_file, suite_name)
447+
class_name = make_suite_class_name(yaml_file, suite.name)
426448
klass = type(class_name, (YamlSpecTestBase,), {})
427449

428450
# Skip the suite if it has no tests.
429-
test_cases = suite.get('tests', [])
430-
if not test_cases:
431-
klass.test_empty_suite = make_async_skip_test_method(yaml_file, suite_name) # type: ignore[attr-defined]
451+
if not suite.tests:
452+
klass.test_empty_suite = make_async_skip_test_method(yaml_file, suite.name) # type: ignore[attr-defined]
432453

433454
# Iterate over the tests in the suite and add them to the class.
434-
for tc_raw in test_cases:
435-
# Normalize the test case data to ensure it has a name.
436-
tc: SpecTest = tc_raw
437-
tc_name = tc.get('desc', 'UnnamedTest')
438-
tc['desc'] = tc_name
439-
455+
for tc in suite.tests:
440456
# Create the test case method and add it to the class.
441-
test_case_name = make_test_case_name(yaml_file, suite_name, tc_name)
442-
test_method = make_async_test_case_method(yaml_file, suite, tc_raw)
457+
test_case_name = make_test_case_name(yaml_file, suite.name, tc.desc)
458+
test_method = make_async_test_case_method(yaml_file, suite, tc)
443459
setattr(klass, test_case_name, test_method)
444460

445461
# Add the test suite class to the module globals.

0 commit comments

Comments
 (0)