@@ -145,6 +145,15 @@ def block(self, begin: str, end: str):
145145 self .level -= 1
146146 self .locals .append (self ._indent (end ) + "\n " )
147147
148+ @contextlib .contextmanager
149+ def comment_if_emitted (self , comment ):
150+ self .comment (comment )
151+ yield
152+ self .comment (comment )
153+ if self .locals [- 1 ] == self .locals [- 2 ]:
154+ self .locals .pop ()
155+ self .locals .pop ()
156+
148157 def get (self , value : ir .Value , default : Any = _UNSPECIFIED ):
149158 if default is _UNSPECIFIED :
150159 return self .env [value ]
@@ -358,6 +367,17 @@ def _print_op(ctx, op):
358367 return bin_op (ctx , "int" , "%" , * op .operands )
359368 case "arith.divsi" :
360369 return bin_op (ctx , "int" , "/" , * op .operands )
370+ case "arith.andi" :
371+ return bin_op (ctx , _model_type (op .result .type ), "&" , * op .operands )
372+ case "arith.select" :
373+ cond , if_true , if_false = map (lambda o : ctx .get (o , None ), op .operands )
374+ if cond is None or if_true is None or if_false is None :
375+ return NotImplemented
376+ result_ty = _model_type (op .result .type )
377+ return ctx .emit (result_ty , f"({ cond } -> { if_true } : { if_false } )" )
378+ case "arith.index_cast" :
379+ model = ctx .get (op .operands [0 ], None )
380+ return ctx .emit ("int" , model ) if model is not None else NotImplemented
361381 case "arith.cmpi" :
362382 match op .predicate .value :
363383 case arith .CmpIPredicate .eq :
@@ -386,12 +406,44 @@ def _print_op(ctx, op):
386406 read_refs .append (model )
387407 with ctx .block ("d_step {" , "}" ): # Start reading
388408 for r in read_refs :
409+ for loc in r .written_at (None ):
410+ ctx .emit (None , f"assert(!{ loc } )" )
389411 for loc in r .readers_at (None ):
390412 ctx .emit (None , f"{ loc } ++" )
391413 with ctx .block ("d_step {" , "}" ): # Stop reading
392414 for r in read_refs :
393415 for loc in r .readers_at (None ):
394416 ctx .emit (None , f"{ loc } --" )
417+ case "vector.load" :
418+ ref = ctx .get (op .operands [0 ])
419+ assert isinstance (ref , GlobalRefModel )
420+ if (first_idx := ctx .get (op .operands [1 ], None )) is not None :
421+ leading_load_len = ir .VectorType (op .result .type ).shape [0 ]
422+ ref = GlobalRefModel (f"{ ref .base } + { first_idx } " , leading_load_len )
423+ with ctx .block ("d_step {" , "}" ): # Start reading
424+ for loc in ref .written_at (None ):
425+ ctx .emit (None , f"assert(!{ loc } )" )
426+ for loc in ref .readers_at (None ):
427+ ctx .emit (None , f"{ loc } ++" )
428+ with ctx .block ("d_step {" , "}" ): # Stop reading
429+ for loc in ref .readers_at (None ):
430+ ctx .emit (None , f"{ loc } --" )
431+ return NotImplemented # We don't model the result of the load.
432+ case "vector.store" :
433+ ref = ctx .get (op .operands [1 ]) # Stored value goes first
434+ assert isinstance (ref , GlobalRefModel )
435+ if (first_idx := ctx .get (op .operands [2 ], None )) is not None :
436+ leading_store_len = ir .VectorType (op .operands [0 ].type ).shape [0 ]
437+ ref = GlobalRefModel (f"{ ref .base } + { first_idx } " , leading_store_len )
438+ with ctx .block ("d_step {" , "}" ): # Start writing
439+ for loc in ref .readers_at (None ):
440+ ctx .emit (None , f"assert(!{ loc } )" )
441+ for loc in ref .written_at (None ):
442+ ctx .emit (None , f"assert(!{ loc } )" )
443+ ctx .emit (None , f"{ loc } = 1" )
444+ with ctx .block ("d_step {" , "}" ): # Stop reading
445+ for loc in ref .written_at (None ):
446+ ctx .emit (None , f"{ loc } = 0" )
395447 case "scf.for" :
396448 carrys = [
397449 ctx .emit ("int" , ctx .get (arg ))
@@ -419,6 +471,7 @@ def _print_op(ctx, op):
419471 ctx .emit (None , f"{ c } = { ctx .get (new )} " )
420472 ctx .emit (None , f"{ induction_var } = { induction_var } + { step } " )
421473 ctx .emit (None , ":: else -> break" )
474+ ctx .emit (None , "skip" ) # To avoid "Jump into d_step sequence errors"
422475 if len (carrys ) == 1 :
423476 return carrys [0 ]
424477 else :
@@ -450,16 +503,27 @@ def bin_op(ctx, result_ty, op, lhs, rhs):
450503 return ctx .emit (result_ty , f"{ lhs } { op } { rhs } " )
451504
452505
506+ def _model_type (ty ):
507+ if ir .IntegerType .isinstance (ty ):
508+ if ir .IntegerType (ty ).width == 1 :
509+ return "bool"
510+ else :
511+ return "int"
512+ else :
513+ raise NotImplementedError (ty )
514+
515+
453516def _print_block (ctx , block ):
454517 for op in block :
455518 try :
456- results = _print_op (ctx , op )
519+ with ctx .comment_if_emitted (op .OPERATION_NAME ):
520+ results = _print_op (ctx , op )
457521 except Exception as e :
458522 raise RuntimeError (f"Failed to print op: { op } " ) from e
459523 if results is NotImplemented :
460524 continue
461525 if not op .results :
462- assert results is None
526+ assert results is None or results == ()
463527 elif len (op .results ) > 1 :
464528 raise NotImplementedError (op )
465529 else :
0 commit comments