From 6ded8e6fbfc6e966915dc2cdc217e2bd3500cee2 Mon Sep 17 00:00:00 2001 From: Louis Mandel Date: Fri, 6 Jun 2025 15:26:37 -0400 Subject: [PATCH] fix: trace generation with context Signed-off-by: Louis Mandel --- src/pdl/pdl_context.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/pdl/pdl_context.py b/src/pdl/pdl_context.py index e59b0e080..941df2abd 100644 --- a/src/pdl/pdl_context.py +++ b/src/pdl/pdl_context.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from collections.abc import Sequence from enum import StrEnum from typing import Any, Callable @@ -21,10 +22,10 @@ class SerializeMode(StrEnum): GRANITEIO = "graniteio" -class PDLContext(Sequence): +class PDLContext(ABC, Sequence): - def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]: - return [] + @abstractmethod + def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]: ... def __add__(self, value: "PDLContext"): return IndependentContext([self, value]) @@ -32,12 +33,6 @@ def __add__(self, value: "PDLContext"): def __mul__(self, value: "PDLContext"): return DependentContext([self, value]) - def __len__(self): - return 0 - - def __getitem__(self, index: int | slice): # pyright: ignore - return [] - # def to_json(self): # return json.dumps(self.serialize(SerializeMode.LITELLM)) @@ -52,16 +47,14 @@ def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]: result = self.message.result() return [result] - def __len__(self): # pyright: ignore + def __len__(self): return 1 def __getitem__(self, index: int | slice): # pyright: ignore - if index in (0, -1): - return self.message.result() - assert False + return [self.message.result()][index] def __repr__(self): # pyright: ignore - return str(self.message.result()) + return self.message.result().__repr__() class IndependentContext(PDLContext): @@ -74,6 +67,8 @@ def __init__(self, context: list[PDLContext]): ret += item.context.data elif isinstance(item, SingletonContext): ret += [item] + elif isinstance(item, DependentContext) and len(item) == 0: + pass else: # Not all elements of the list are Independent, so return self.context = PdlList(context) @@ -111,6 +106,8 @@ def __init__(self, context: list[PDLContext]): ret += item.context.data elif isinstance(item, SingletonContext): ret += [item] + elif isinstance(item, IndependentContext) and len(item) == 0: + pass else: # Not all elements of the list are Dependent, so return self.context = PdlList(context)