99from ..utils import bits_for
1010from .. import tracer
1111from ._ast import *
12+ from ._ast import _StatementList , _LateBoundStatement , Property
1213from ._ir import *
1314from ._cd import *
1415from ._xfrm import *
@@ -146,16 +147,50 @@ def helper(*args, **kwds):
146147 return decorator
147148
148149
150+ class FSMNextStatement (_LateBoundStatement ):
151+ def __init__ (self , ctrl_data , state , src_loc_at = 0 ):
152+ self .ctrl_data = ctrl_data
153+ self .state = state
154+ super ().__init__ (src_loc_at = 1 + src_loc_at )
155+
156+ def resolve (self ):
157+ return self .ctrl_data ["signal" ].eq (self .ctrl_data ["encoding" ][self .state ])
158+
159+
149160class FSM :
150- def __init__ (self , state , encoding , decoding ):
151- self .state = state
152- self .encoding = encoding
153- self .decoding = decoding
161+ def __init__ (self , data ):
162+ self ._data = data
163+ self .encoding = data [ " encoding" ]
164+ self .decoding = data [ " decoding" ]
154165
155166 def ongoing (self , name ):
156167 if name not in self .encoding :
157168 self .encoding [name ] = len (self .encoding )
158- return Operator ("==" , [self .state , self .encoding [name ]], src_loc_at = 0 )
169+ fsm_name = self ._data ["name" ]
170+ self ._data ["ongoing" ][name ] = Signal (name = f"{ fsm_name } _ongoing_{ name } " )
171+ return self ._data ["ongoing" ][name ]
172+
173+
174+ def resolve_statement (stmt ):
175+ if isinstance (stmt , _LateBoundStatement ):
176+ return resolve_statement (stmt .resolve ())
177+ elif isinstance (stmt , Switch ):
178+ return Switch (
179+ test = stmt .test ,
180+ cases = OrderedDict (
181+ (patterns , resolve_statements (stmts ))
182+ for patterns , stmts in stmt .cases .items ()
183+ ),
184+ src_loc = stmt .src_loc ,
185+ case_src_locs = stmt .case_src_locs ,
186+ )
187+ elif isinstance (stmt , (Assign , Property )):
188+ return stmt
189+ else :
190+ assert False # nocov
191+
192+ def resolve_statements (stmts ):
193+ return _StatementList (resolve_statement (stmt ) for stmt in stmts )
159194
160195
161196class Module (_ModuleBuilderRoot , Elaboratable ):
@@ -172,6 +207,7 @@ def __init__(self):
172207 self ._statements = {}
173208 self ._ctrl_context = None
174209 self ._ctrl_stack = []
210+ self ._top_comb_statements = _StatementList ()
175211
176212 self ._driving = SignalDict ()
177213 self ._named_submodules = {}
@@ -391,17 +427,16 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
391427 init = reset
392428 fsm_data = self ._set_ctrl ("FSM" , {
393429 "name" : name ,
394- "signal" : Signal (name = f"{ name } _state" , src_loc_at = 2 ),
395430 "init" : init ,
396431 "domain" : domain ,
397432 "encoding" : OrderedDict (),
398433 "decoding" : OrderedDict (),
434+ "ongoing" : {},
399435 "states" : OrderedDict (),
400436 "src_loc" : tracer .get_src_loc (src_loc_at = 1 ),
401437 "state_src_locs" : {},
402438 })
403- self ._generated [name ] = fsm = \
404- FSM (fsm_data ["signal" ], fsm_data ["encoding" ], fsm_data ["decoding" ])
439+ self ._generated [name ] = fsm = FSM (fsm_data )
405440 try :
406441 self ._ctrl_context = "FSM"
407442 self .domain ._depth += 1
@@ -414,6 +449,7 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
414449 self .domain ._depth -= 1
415450 self ._ctrl_context = None
416451 self ._pop_ctrl ()
452+ fsm .state = fsm_data ["signal" ]
417453
418454 @contextmanager
419455 def State (self , name ):
@@ -423,7 +459,9 @@ def State(self, name):
423459 if name in fsm_data ["states" ]:
424460 raise NameError (f"FSM state '{ name } ' is already defined" )
425461 if name not in fsm_data ["encoding" ]:
462+ fsm_name = fsm_data ["name" ]
426463 fsm_data ["encoding" ][name ] = len (fsm_data ["encoding" ])
464+ fsm_data ["ongoing" ][name ] = Signal (name = f"{ fsm_name } _ongoing_{ name } " )
427465 try :
428466 _outer_case , self ._statements = self ._statements , {}
429467 self ._ctrl_context = None
@@ -445,9 +483,11 @@ def next(self, name):
445483 for level , (ctrl_name , ctrl_data ) in enumerate (reversed (self ._ctrl_stack )):
446484 if ctrl_name == "FSM" :
447485 if name not in ctrl_data ["encoding" ]:
486+ fsm_name = ctrl_data ["name" ]
448487 ctrl_data ["encoding" ][name ] = len (ctrl_data ["encoding" ])
488+ ctrl_data ["ongoing" ][name ] = Signal (name = f"{ fsm_name } _ongoing_{ name } " )
449489 self ._add_statement (
450- assigns = [ctrl_data [ "signal" ]. eq (ctrl_data [ "encoding" ][ name ] )],
490+ assigns = [FSMNextStatement (ctrl_data , name )],
451491 domain = ctrl_data ["domain" ],
452492 depth = len (self ._ctrl_stack ))
453493 return
@@ -500,19 +540,25 @@ def _pop_ctrl(self):
500540 src_loc = src_loc , case_src_locs = switch_case_src_locs ))
501541
502542 if name == "FSM" :
503- fsm_signal , fsm_init , fsm_encoding , fsm_decoding , fsm_states = \
504- data ["signal " ], data ["init" ], data ["encoding" ], data ["decoding" ], data ["states" ]
543+ fsm_name , fsm_init , fsm_encoding , fsm_decoding , fsm_states , fsm_ongoing = \
544+ data ["name " ], data ["init" ], data ["encoding" ], data ["decoding" ], data ["states" ], data [ "ongoing " ]
505545 fsm_state_src_locs = data ["state_src_locs" ]
506546 if not fsm_states :
547+ data ["signal" ] = Signal (0 , name = f"{ fsm_name } _state" , src_loc_at = 2 )
507548 return
508- fsm_signal .width = bits_for (len (fsm_encoding ) - 1 )
509549 if fsm_init is None :
510- fsm_signal . init = fsm_encoding [next (iter (fsm_states ))]
550+ init = fsm_encoding [next (iter (fsm_states ))]
511551 else :
512- fsm_signal . init = fsm_encoding [fsm_init ]
552+ init = fsm_encoding [fsm_init ]
513553 # The FSM is encoded such that the state with encoding 0 is always the init state.
514554 fsm_decoding .update ((n , s ) for s , n in fsm_encoding .items ())
515- fsm_signal .decoder = lambda n : f"{ fsm_decoding [n ]} /{ n } "
555+ data ["signal" ] = fsm_signal = Signal (range (len (fsm_encoding )), init = init ,
556+ name = f"{ fsm_name } _state" , src_loc_at = 2 ,
557+ decoder = lambda n : f"{ fsm_decoding [n ]} /{ n } " )
558+
559+ for name , sig in fsm_ongoing .items ():
560+ self ._top_comb_statements .append (
561+ sig .eq (Operator ("==" , [fsm_signal , fsm_encoding [name ]], src_loc_at = 0 )))
516562
517563 domains = set ()
518564 for stmts in fsm_states .values ():
@@ -533,20 +579,21 @@ def _add_statement(self, assigns, domain, depth):
533579 self ._pop_ctrl ()
534580
535581 for stmt in Statement .cast (assigns ):
536- if not isinstance (stmt , (Assign , Property )):
582+ if not isinstance (stmt , (Assign , Property , _LateBoundStatement )):
537583 raise SyntaxError (
538584 f"Only assignments and property checks may be appended to d.{ domain } " )
539585
540586 stmt ._MustUse__used = True
541587
542- for signal in stmt ._lhs_signals ():
543- if signal not in self ._driving :
544- self ._driving [signal ] = domain
545- elif self ._driving [signal ] != domain :
546- cd_curr = self ._driving [signal ]
547- raise SyntaxError (
548- f"Driver-driver conflict: trying to drive { signal !r} from d.{ domain } , but it is "
549- f"already driven from d.{ cd_curr } " )
588+ if isinstance (stmt , Assign ):
589+ for signal in stmt ._lhs_signals ():
590+ if signal not in self ._driving :
591+ self ._driving [signal ] = domain
592+ elif self ._driving [signal ] != domain :
593+ cd_curr = self ._driving [signal ]
594+ raise SyntaxError (
595+ f"Driver-driver conflict: trying to drive { signal !r} from d.{ domain } , but it is "
596+ f"already driven from d.{ cd_curr } " )
550597
551598 self ._statements .setdefault (domain , []).append (stmt )
552599
@@ -586,9 +633,13 @@ def elaborate(self, platform):
586633 for submodule , src_loc in self ._anon_submodules :
587634 fragment .add_subfragment (Fragment .get (submodule , platform ), None , src_loc = src_loc )
588635 for domain , statements in self ._statements .items ():
636+ statements = resolve_statements (statements )
589637 fragment .add_statements (domain , statements )
590- for signal , domain in self ._driving .items ():
591- fragment .add_driver (signal , domain )
638+ for signal in statements ._lhs_signals ():
639+ fragment .add_driver (signal , domain )
640+ fragment .add_statements ("comb" , self ._top_comb_statements )
641+ for signal in self ._top_comb_statements ._lhs_signals ():
642+ fragment .add_driver (signal , "comb" )
592643 fragment .add_domains (self ._domains .values ())
593644 fragment .generated .update (self ._generated )
594645 return fragment
0 commit comments