diff --git a/src/pdl/pdl_context.py b/src/pdl/pdl_context.py index 941df2abd..be9648f02 100644 --- a/src/pdl/pdl_context.py +++ b/src/pdl/pdl_context.py @@ -40,8 +40,11 @@ def __mul__(self, value: "PDLContext"): class SingletonContext(PDLContext): message: PdlLazy[dict[str, Any]] - def __init__(self, message: PdlLazy[dict[str, Any]]): - self.message = message + def __init__(self, message: PdlLazy[dict[str, Any]] | dict[str, Any]): + if isinstance(message, PdlLazy): + self.message = message + else: + self.message = PdlConst(message) def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]: result = self.message.result() @@ -60,20 +63,23 @@ def __repr__(self): # pyright: ignore class IndependentContext(PDLContext): context: PdlLazy[list[PDLContext]] - def __init__(self, context: list[PDLContext]): + def __init__(self, context: list[PDLContext | dict[str, Any]]): ret: list[PDLContext] = [] for item in context: - if isinstance(item, IndependentContext): - ret += item.context.data - elif isinstance(item, SingletonContext): - ret += [item] - elif isinstance(item, DependentContext) and len(item) == 0: - pass - else: - # Not all elements of the list are Independent, so return - self.context = PdlList(context) - return - # All elements of the list are Independent + match item: + case IndependentContext(): + ret += item.context.data + case SingletonContext(): + ret += [item] + case DependentContext(): + if len(item) == 0: + pass + else: + ret += [item] + case dict(): + ret += [SingletonContext(item)] + case _: + assert False self.context = PdlList(ret) def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]: @@ -99,20 +105,23 @@ def __repr__(self): # pyright: ignore class DependentContext(PDLContext): context: PdlLazy[list[PDLContext]] - def __init__(self, context: list[PDLContext]): + def __init__(self, context: list[PDLContext | dict[str, Any]]): ret: list[PDLContext] = [] for item in context: - if isinstance(item, DependentContext): - ret += item.context.data - elif isinstance(item, SingletonContext): - ret += [item] - elif isinstance(item, IndependentContext) and len(item) == 0: - pass - else: - # Not all elements of the list are Dependent, so return - self.context = PdlList(context) - return - # All elements of the list are Dependent + match item: + case DependentContext(): + ret += item.context.data + case SingletonContext(): + ret += [item] + case IndependentContext(): + if len(item) == 0: + pass + else: + ret += [item] + case dict(): + ret += [SingletonContext(item)] + case _: + assert False self.context = PdlList(ret) def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]: diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index 85fe87041..d5b0f76e5 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -485,7 +485,7 @@ def process_advanced_block( background = DependentContext([]) contribute_value, trace = process_contribute(trace, new_scope, loc) if contribute_value is not None: - background = DependentContext([contribute_value]) + background = DependentContext(contribute_value) return result, background, new_scope, trace