99from ..utils import bits_for
1010from .. import tracer
1111from ._ast import *
12+ from ._ast import _StatementList , _LateBoundStatement , Property
1213from ._ir import *
1314from ._cd import *
1415from ._xfrm import *
@@ -146,16 +147,51 @@ def helper(*args, **kwds):
146147 return decorator
147148
148149
150+ class FSMNextStatement (_LateBoundStatement ):
151+ def __init__ (self , ctrl_data , state , * , src_loc_at = 0 ):
152+ self .ctrl_data = ctrl_data
153+ self .state = state
154+ super ().__init__ (src_loc_at = 1 + src_loc_at )
155+
156+ def resolve (self ):
157+ return self .ctrl_data ["signal" ].eq (self .ctrl_data ["encoding" ][self .state ])
158+
159+
149160class FSM :
150- def __init__ (self , state , encoding , decoding ):
151- self .state = state
152- self .encoding = encoding
153- self .decoding = decoding
161+ def __init__ (self , data ):
162+ self ._data = data
163+ self .encoding = data [ " encoding" ]
164+ self .decoding = data [ " decoding" ]
154165
155166 def ongoing (self , name ):
156167 if name not in self .encoding :
157168 self .encoding [name ] = len (self .encoding )
158- return Operator ("==" , [self .state , self .encoding [name ]], src_loc_at = 0 )
169+ fsm_name = self ._data ["name" ]
170+ self ._data ["ongoing" ][name ] = Signal (name = f"{ fsm_name } _ongoing_{ name } " )
171+ return self ._data ["ongoing" ][name ]
172+
173+
174+ def resolve_statement (stmt ):
175+ if isinstance (stmt , _LateBoundStatement ):
176+ return resolve_statement (stmt .resolve ())
177+ elif isinstance (stmt , Switch ):
178+ return Switch (
179+ test = stmt .test ,
180+ cases = OrderedDict (
181+ (patterns , resolve_statements (stmts ))
182+ for patterns , stmts in stmt .cases .items ()
183+ ),
184+ src_loc = stmt .src_loc ,
185+ case_src_locs = stmt .case_src_locs ,
186+ )
187+ elif isinstance (stmt , (Assign , Property )):
188+ return stmt
189+ else :
190+ assert False # :nocov:
191+
192+
193+ def resolve_statements (stmts ):
194+ return _StatementList (resolve_statement (stmt ) for stmt in stmts )
159195
160196
161197class Module (_ModuleBuilderRoot , Elaboratable ):
@@ -172,6 +208,7 @@ def __init__(self):
172208 self ._statements = {}
173209 self ._ctrl_context = None
174210 self ._ctrl_stack = []
211+ self ._top_comb_statements = _StatementList ()
175212
176213 self ._driving = SignalDict ()
177214 self ._named_submodules = {}
@@ -391,17 +428,16 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
391428 init = reset
392429 fsm_data = self ._set_ctrl ("FSM" , {
393430 "name" : name ,
394- "signal" : Signal (name = f"{ name } _state" , src_loc_at = 2 ),
395431 "init" : init ,
396432 "domain" : domain ,
397433 "encoding" : OrderedDict (),
398434 "decoding" : OrderedDict (),
435+ "ongoing" : {},
399436 "states" : OrderedDict (),
400437 "src_loc" : tracer .get_src_loc (src_loc_at = 1 ),
401438 "state_src_locs" : {},
402439 })
403- self ._generated [name ] = fsm = \
404- FSM (fsm_data ["signal" ], fsm_data ["encoding" ], fsm_data ["decoding" ])
440+ self ._generated [name ] = fsm = FSM (fsm_data )
405441 try :
406442 self ._ctrl_context = "FSM"
407443 self .domain ._depth += 1
@@ -414,6 +450,7 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
414450 self .domain ._depth -= 1
415451 self ._ctrl_context = None
416452 self ._pop_ctrl ()
453+ fsm .state = fsm_data ["signal" ]
417454
418455 @contextmanager
419456 def State (self , name ):
@@ -423,7 +460,9 @@ def State(self, name):
423460 if name in fsm_data ["states" ]:
424461 raise NameError (f"FSM state '{ name } ' is already defined" )
425462 if name not in fsm_data ["encoding" ]:
463+ fsm_name = fsm_data ["name" ]
426464 fsm_data ["encoding" ][name ] = len (fsm_data ["encoding" ])
465+ fsm_data ["ongoing" ][name ] = Signal (name = f"{ fsm_name } _ongoing_{ name } " )
427466 try :
428467 _outer_case , self ._statements = self ._statements , {}
429468 self ._ctrl_context = None
@@ -445,9 +484,11 @@ def next(self, name):
445484 for level , (ctrl_name , ctrl_data ) in enumerate (reversed (self ._ctrl_stack )):
446485 if ctrl_name == "FSM" :
447486 if name not in ctrl_data ["encoding" ]:
487+ fsm_name = ctrl_data ["name" ]
448488 ctrl_data ["encoding" ][name ] = len (ctrl_data ["encoding" ])
489+ ctrl_data ["ongoing" ][name ] = Signal (name = f"{ fsm_name } _ongoing_{ name } " )
449490 self ._add_statement (
450- assigns = [ctrl_data [ "signal" ]. eq (ctrl_data [ "encoding" ][ name ] )],
491+ assigns = [FSMNextStatement (ctrl_data , name )],
451492 domain = ctrl_data ["domain" ],
452493 depth = len (self ._ctrl_stack ))
453494 return
@@ -500,19 +541,25 @@ def _pop_ctrl(self):
500541 src_loc = src_loc , case_src_locs = switch_case_src_locs ))
501542
502543 if name == "FSM" :
503- fsm_signal , fsm_init , fsm_encoding , fsm_decoding , fsm_states = \
504- data ["signal " ], data ["init" ], data ["encoding" ], data ["decoding" ], data ["states" ]
544+ fsm_name , fsm_init , fsm_encoding , fsm_decoding , fsm_states , fsm_ongoing = \
545+ data ["name " ], data ["init" ], data ["encoding" ], data ["decoding" ], data ["states" ], data [ "ongoing " ]
505546 fsm_state_src_locs = data ["state_src_locs" ]
506547 if not fsm_states :
548+ data ["signal" ] = Signal (0 , name = f"{ fsm_name } _state" , src_loc_at = 2 )
507549 return
508- fsm_signal .width = bits_for (len (fsm_encoding ) - 1 )
509550 if fsm_init is None :
510- fsm_signal . init = fsm_encoding [next (iter (fsm_states ))]
551+ init = fsm_encoding [next (iter (fsm_states ))]
511552 else :
512- fsm_signal . init = fsm_encoding [fsm_init ]
553+ init = fsm_encoding [fsm_init ]
513554 # The FSM is encoded such that the state with encoding 0 is always the init state.
514555 fsm_decoding .update ((n , s ) for s , n in fsm_encoding .items ())
515- fsm_signal .decoder = lambda n : f"{ fsm_decoding [n ]} /{ n } "
556+ data ["signal" ] = fsm_signal = Signal (range (len (fsm_encoding )), init = init ,
557+ name = f"{ fsm_name } _state" , src_loc_at = 2 ,
558+ decoder = lambda n : f"{ fsm_decoding [n ]} /{ n } " )
559+
560+ for name , sig in fsm_ongoing .items ():
561+ self ._top_comb_statements .append (
562+ sig .eq (Operator ("==" , [fsm_signal , fsm_encoding [name ]], src_loc_at = 0 )))
516563
517564 domains = set ()
518565 for stmts in fsm_states .values ():
@@ -533,20 +580,21 @@ def _add_statement(self, assigns, domain, depth):
533580 self ._pop_ctrl ()
534581
535582 for stmt in Statement .cast (assigns ):
536- if not isinstance (stmt , (Assign , Property )):
583+ if not isinstance (stmt , (Assign , Property , _LateBoundStatement )):
537584 raise SyntaxError (
538585 f"Only assignments and property checks may be appended to d.{ domain } " )
539586
540587 stmt ._MustUse__used = True
541588
542- for signal in stmt ._lhs_signals ():
543- if signal not in self ._driving :
544- self ._driving [signal ] = domain
545- elif self ._driving [signal ] != domain :
546- cd_curr = self ._driving [signal ]
547- raise SyntaxError (
548- f"Driver-driver conflict: trying to drive { signal !r} from d.{ domain } , but it is "
549- f"already driven from d.{ cd_curr } " )
589+ if isinstance (stmt , Assign ):
590+ for signal in stmt ._lhs_signals ():
591+ if signal not in self ._driving :
592+ self ._driving [signal ] = domain
593+ elif self ._driving [signal ] != domain :
594+ cd_curr = self ._driving [signal ]
595+ raise SyntaxError (
596+ f"Driver-driver conflict: trying to drive { signal !r} from d.{ domain } , but it is "
597+ f"already driven from d.{ cd_curr } " )
550598
551599 self ._statements .setdefault (domain , []).append (stmt )
552600
@@ -586,9 +634,13 @@ def elaborate(self, platform):
586634 for submodule , src_loc in self ._anon_submodules :
587635 fragment .add_subfragment (Fragment .get (submodule , platform ), None , src_loc = src_loc )
588636 for domain , statements in self ._statements .items ():
637+ statements = resolve_statements (statements )
589638 fragment .add_statements (domain , statements )
590- for signal , domain in self ._driving .items ():
591- fragment .add_driver (signal , domain )
639+ for signal in statements ._lhs_signals ():
640+ fragment .add_driver (signal , domain )
641+ fragment .add_statements ("comb" , self ._top_comb_statements )
642+ for signal in self ._top_comb_statements ._lhs_signals ():
643+ fragment .add_driver (signal , "comb" )
592644 fragment .add_domains (self ._domains .values ())
593645 fragment .generated .update (self ._generated )
594646 return fragment
0 commit comments