Skip to content

Commit b97dcf1

Browse files
authored
refine scf.for pretty printing (#183)
fix #177
1 parent 38fa727 commit b97dcf1

File tree

4 files changed

+64
-30
lines changed

4 files changed

+64
-30
lines changed

src/kirin/dialects/scf/stmts.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,27 @@ def verify(self) -> None:
148148
def print_impl(self, printer: Printer) -> None:
149149
printer.print_name(self)
150150
printer.plain_print(" ")
151+
block = self.body.blocks[0]
152+
printer.print(block.args[0])
153+
printer.plain_print(" in ", style="keyword")
151154
printer.print(self.iterable)
152-
printer.plain_print(" init ")
153-
printer.print_seq(self.initializers, delim=", ")
154-
printer.plain_print(" ")
155-
printer.print(self.body)
155+
with printer.indent():
156+
printer.print_newline()
157+
printer.plain_print("iter_args(")
158+
for idx, (arg, val) in enumerate(zip(block.args[1:], self.initializers)):
159+
printer.print(arg)
160+
printer.plain_print(" = ")
161+
printer.print(val)
162+
if idx < len(self.initializers) - 1:
163+
printer.plain_print(", ")
164+
printer.plain_print(") {")
165+
166+
with printer.align(printer.result_width(block.stmts)):
167+
for stmt in block.stmts:
168+
printer.print_newline()
169+
printer.print_stmt(stmt)
170+
printer.print_newline()
171+
printer.plain_print("}")
156172

157173

158174
@statement(dialect=dialect)

src/kirin/ir/nodes/block.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -439,26 +439,7 @@ def print_impl(self, printer: Printer) -> None:
439439
with printer.indent(increase=2, mark=False):
440440
for stmt in self.stmts:
441441
printer.print_newline()
442-
if stmt._results:
443-
result_str = printer.result_str(stmt._results)
444-
printer.plain_print(
445-
result_str.rjust(printer.state.result_width), " = "
446-
)
447-
elif printer.state.result_width:
448-
printer.plain_print(" " * printer.state.result_width, " ")
449-
with printer.indent(printer.state.result_width + 3, mark=True):
450-
printer.print(stmt)
451-
if printer.analysis and any(
452-
result in printer.analysis for result in stmt._results
453-
):
454-
with printer.rich(style="warning"):
455-
printer.plain_print(" # ---> ")
456-
printer.plain_print(
457-
", ".join(
458-
repr(printer.analysis[result])
459-
for result in stmt._results
460-
)
461-
)
442+
printer.print_stmt(stmt)
462443

463444
def typecheck(self) -> None:
464445
"""Checking the types of the Statments in the Block."""

src/kirin/ir/nodes/region.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,15 @@ def walk(
282282
for block in reversed(self.blocks) if reverse else self.blocks:
283283
yield from block.walk(reverse=reverse, region_first=region_first)
284284

285+
def stmts(self) -> Iterator[Statement]:
286+
"""Iterate over all the Statements in the Region. This does not walk into nested Regions.
287+
288+
Yields:
289+
Iterator[Statement]: An iterator that yield Statements of Blocks in the Region.
290+
"""
291+
for block in self.blocks:
292+
yield from block.stmts
293+
285294
def print_impl(self, printer: Printer) -> None:
286295
# populate block ids
287296
for block in self.blocks:
@@ -293,12 +302,7 @@ def print_impl(self, printer: Printer) -> None:
293302
printer.plain_print("}")
294303
return
295304

296-
result_width = 0
297-
for bb in self.blocks:
298-
for stmt in bb.stmts:
299-
result_width = max(result_width, len(printer.result_str(stmt._results)))
300-
301-
with printer.align(result_width):
305+
with printer.align(printer.result_width(self.stmts())):
302306
with printer.indent(increase=2, mark=True):
303307
printer.print_newline()
304308
for idx, bb in enumerate(self.blocks):

src/kirin/print/printer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,25 @@ def print_name(
163163
self.plain_print(".")
164164
self.plain_print(node.name)
165165

166+
def print_stmt(self, node: "ir.Statement"):
167+
if node._results:
168+
result_str = self.result_str(node._results)
169+
self.plain_print(result_str.rjust(self.state.result_width), " = ")
170+
elif self.state.result_width:
171+
self.plain_print(" " * self.state.result_width, " ")
172+
with self.indent(self.state.result_width + 3, mark=True):
173+
self.print(node)
174+
if self.analysis and any(
175+
result in self.analysis for result in node._results
176+
):
177+
with self.rich(style="warning"):
178+
self.plain_print(" # ---> ")
179+
self.plain_print(
180+
", ".join(
181+
repr(self.analysis[result]) for result in node._results
182+
)
183+
)
184+
166185
def print_dialect_path(
167186
self, node: Union["ir.Attribute", "ir.Statement"], prefix: str = ""
168187
) -> None:
@@ -286,6 +305,20 @@ def print_mapping(
286305
self.plain_print(f"{key}=")
287306
emit(value)
288307

308+
def result_width(self, stmts: Iterable["ir.Statement"]) -> int:
309+
"""return the maximum width of the result column for a sequence of statements.
310+
311+
Args:
312+
stmts(Iterable[ir.Statement]): sequence of statements
313+
314+
Returns:
315+
int: maximum width of the result column
316+
"""
317+
result_width = 0
318+
for stmt in stmts:
319+
result_width = max(result_width, len(self.result_str(stmt._results)))
320+
return result_width
321+
289322
@contextmanager
290323
def align(self, width: int):
291324
"""align the result column width, and restore it after the context.

0 commit comments

Comments
 (0)