22"""
33
44from enum import StrEnum
5- from typing import Any , Literal , Mapping , Optional , Sequence , TypeAlias , Union
5+ from typing import (
6+ Any ,
7+ Generic ,
8+ Literal ,
9+ Mapping ,
10+ Optional ,
11+ Sequence ,
12+ TypeAlias ,
13+ TypeVar ,
14+ Union ,
15+ )
616
717from pydantic import BaseModel , ConfigDict , Field , RootModel
818from pydantic .json_schema import SkipJsonSchema
@@ -51,26 +61,24 @@ class LocationType(BaseModel):
5161empty_block_location = LocationType (file = "" , path = [], table = {})
5262
5363
54- class LocalizedExpression (BaseModel ):
64+ LocalizedExpressionT = TypeVar ("LocalizedExpressionT" )
65+
66+
67+ class LocalizedExpression (BaseModel , Generic [LocalizedExpressionT ]):
5568 """Expression with location information"""
5669
5770 model_config = ConfigDict (
58- extra = "forbid" , use_attribute_docstrings = True , arbitrary_types_allowed = True
71+ extra = "forbid" ,
72+ use_attribute_docstrings = True ,
73+ arbitrary_types_allowed = True ,
74+ model_title_generator = (lambda _ : "LocalizedExpression" ),
5975 )
60- expr : Any
76+ expr : LocalizedExpressionT
6177 location : Optional [LocationType ] = None
6278
6379
64- ExpressionType : TypeAlias = Any | LocalizedExpression
65- # (
66- # str
67- # | int
68- # | float
69- # | bool
70- # | None
71- # | list["ExpressionType"]
72- # | dict[str, "ExpressionType"]
73- # )
80+ ExpressionTypeT = TypeVar ("ExpressionTypeT" )
81+ ExpressionType : TypeAlias = ExpressionTypeT | str | LocalizedExpression [ExpressionTypeT ]
7482
7583
7684class Pattern (BaseModel ):
@@ -316,7 +324,7 @@ class ModelPlatform(StrEnum):
316324
317325class ModelBlock (Block ):
318326 kind : Literal [BlockKind .MODEL ] = BlockKind .MODEL
319- model : str | ExpressionType
327+ model : ExpressionType [ str ]
320328 input : Optional ["BlockType" ] = None
321329 trace : Optional ["BlockType" ] = None
322330 modelResponse : Optional [str ] = None
@@ -326,7 +334,7 @@ class LitellmModelBlock(ModelBlock):
326334 """Call a LLM through the LiteLLM API: https://docs.litellm.ai/."""
327335
328336 platform : Literal [ModelPlatform .LITELLM ] = ModelPlatform .LITELLM
329- parameters : Optional [LitellmParameters | ExpressionType ] = None
337+ parameters : Optional [LitellmParameters | ExpressionType [ dict ] ] = None
330338
331339
332340class CodeBlock (Block ):
@@ -353,7 +361,7 @@ class DataBlock(Block):
353361 """Arbitrary JSON value."""
354362
355363 kind : Literal [BlockKind .DATA ] = BlockKind .DATA
356- data : ExpressionType
364+ data : ExpressionType [ Any ]
357365 """Value defined."""
358366 raw : bool = False
359367 """Do not evaluate expressions inside strings."""
@@ -403,7 +411,7 @@ class IfBlock(Block):
403411 """Conditional control structure."""
404412
405413 kind : Literal [BlockKind .IF ] = BlockKind .IF
406- condition : ExpressionType = Field (alias = "if" )
414+ condition : ExpressionType [ bool ] = Field (alias = "if" )
407415 """Condition.
408416 """
409417 then : "BlockType"
@@ -423,7 +431,7 @@ class MatchCase(BaseModel):
423431 case : Optional [PatternType ] = None
424432 """Value to match.
425433 """
426- if_ : Optional [ExpressionType ] = Field (default = None , alias = "if" )
434+ if_ : Optional [ExpressionType [ bool ] ] = Field (default = None , alias = "if" )
427435 """Boolean condition to satisfy.
428436 """
429437 then : "BlockType"
@@ -435,7 +443,7 @@ class MatchBlock(Block):
435443 """Match control structure."""
436444
437445 kind : Literal [BlockKind .MATCH ] = BlockKind .MATCH
438- match_ : ExpressionType = Field (alias = "match" )
446+ match_ : ExpressionType [ Any ] = Field (alias = "match" )
439447 """Matched expression.
440448 """
441449 with_ : list [MatchCase ] = Field (alias = "with" )
@@ -484,19 +492,19 @@ class RepeatBlock(Block):
484492 """Repeat the execution of a block."""
485493
486494 kind : Literal [BlockKind .REPEAT ] = BlockKind .REPEAT
487- fors : Optional [dict [str , ExpressionType ]] = Field (default = None , alias = "for" )
495+ fors : Optional [dict [str , ExpressionType [ list ] ]] = Field (default = None , alias = "for" )
488496 """Arrays to iterate over.
489497 """
490- while_ : ExpressionType = Field (default = True , alias = "while" )
498+ while_ : ExpressionType [ bool ] = Field (default = True , alias = "while" )
491499 """Condition to stay at the beginning of the loop.
492500 """
493501 repeat : "BlockType"
494502 """Body of the loop.
495503 """
496- until : ExpressionType = False
504+ until : ExpressionType [ bool ] = False
497505 """Condition to exit at the end of the loop.
498506 """
499- max_iterations : Optional [ExpressionType ] = None
507+ max_iterations : Optional [ExpressionType [ int ] ] = None
500508 """Maximal number of iterations to perform.
501509 """
502510 join : JoinType = JoinText ()
@@ -510,7 +518,7 @@ class ReadBlock(Block):
510518 """Read from a file or standard input."""
511519
512520 kind : Literal [BlockKind .READ ] = BlockKind .READ
513- read : ExpressionType | None
521+ read : ExpressionType [ str ] | None
514522 """Name of the file to read. If `None`, read the standard input.
515523 """
516524 message : Optional [str ] = None
0 commit comments