Skip to content

Commit 0805af3

Browse files
chunnienc authored and copybara-github committed
Add progress tracking lib
PiperOrigin-RevId: 871707321
1 parent 1d20114 commit 0805af3

7 files changed

Lines changed: 154 additions & 28 deletions

File tree

litert_torch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
1516
from litert_torch._config import config
1617
from litert_torch._convert.interface import convert
1718
from litert_torch._convert.interface import experimental_add_compilation_backend

litert_torch/_config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,5 +93,14 @@ def lazy_constant_getter_chunk_size(self) -> int:
9393
def lazy_constant_getter_chunk_size(self, value: int) -> None:
9494
os.environ["LAZY_CONSTANT_GETTER_CHUNK_SIZE"] = str(value)
9595

96+
@property
def show_progress(self) -> bool:
  """Whether progress output is enabled.

  Backed by the LITERT_TORCH_SHOW_PROGRESS environment variable;
  defaults to True when the variable is unset.
  """
  return _get_bool_env_var("LITERT_TORCH_SHOW_PROGRESS", default=True)
100+
101+
@show_progress.setter
def show_progress(self, value: bool):
  # Persist via the environment so the getter (which reads the env var)
  # observes the new value.
  flag = "y" if value else "n"
  os.environ["LITERT_TORCH_SHOW_PROGRESS"] = flag
104+
96105

97106
config = _Config()

litert_torch/_convert/core.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from litert_torch import fx_infra
2222
from litert_torch import model
23+
from litert_torch import progress
2324
from litert_torch._convert import fx_passes
2425
from litert_torch._convert import litert_converter
2526
from litert_torch._convert import signature
@@ -70,6 +71,7 @@ def _warn_training_modules(signatures: list[signature.Signature]):
7071
logging.warning(message)
7172

7273

74+
@progress.task("LiteRT-Torch Convert")
7375
def convert_signatures(
7476
signatures: list[signature.Signature],
7577
*,
@@ -127,18 +129,20 @@ def export(**kwargs):
127129
)
128130
return exported_program
129131

130-
exported_programs = [
131-
export(
132+
exported_programs = []
133+
for sig in signatures:
134+
with progress.task(f"Torch Export: {sig.name}"):
135+
exported_program = export(
132136
mod=sig.module,
133137
args=sig.args,
134138
kwargs=sig.kwargs,
135139
dynamic_shapes=sig.dynamic_shapes,
136140
)
137-
for sig in signatures
138-
]
141+
exported_programs.append(exported_program)
139142

140143
# Apply default fx passes
141-
exported_programs = list(map(_run_convert_passes, exported_programs))
144+
with progress.task("Run FX Passes"):
145+
exported_programs = list(map(_run_convert_passes, exported_programs))
142146

143147
exporter = litert_converter.exported_programs_to_flatbuffer(
144148
exported_programs,

litert_torch/_convert/litert_converter.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from litert_torch import backend
2020
from litert_torch import model as model_lib
21+
from litert_torch import progress
2122
from litert_torch._convert import signature
2223
from litert_torch.backend import inline_consts as inline_consts_lib
2324
from litert_torch.quantize import quant_config as qcfg
@@ -75,16 +76,20 @@ def to_file(self, path: str):
7576
f.write(self.content)
7677
return
7778

78-
try:
79-
# TODO b/478909085 - Remove the try-except once converter_api_ext is
80-
# stable in OSS.
81-
converter_api_ext.export_flatbuffer_to_file(self._module_op, path)
82-
except TypeError:
83-
converter_api_ext.export_flatbuffer_to_file(self.module, path)
79+
with progress.task(f"Write Model to {path}"):
80+
try:
81+
# TODO b/478909085 - Remove the try-except once converter_api_ext is
82+
# stable in OSS.
83+
converter_api_ext.export_flatbuffer_to_file(self._module_op, path)
84+
except TypeError:
85+
converter_api_ext.export_flatbuffer_to_file(self.module, path)
8486

8587
def to_bytes(self) -> bytes:
8688
"""Returns the flatbuffer bytes of the module."""
87-
if self.content is None:
89+
if self.content is not None:
90+
return self.content
91+
92+
with progress.task("Write Model to Bytes"):
8893
try:
8994
# TODO b/478909085 - Remove the try-except once converter_api_ext is
9095
# stable in OSS.
@@ -93,8 +98,8 @@ def to_bytes(self) -> bytes:
9398
)
9499
except TypeError:
95100
self.content = converter_api_ext.export_flatbuffer_to_bytes(self.module)
96-
self.module = None
97101

102+
self.module = None
98103
return self.content
99104

100105

@@ -114,19 +119,19 @@ def exported_programs_to_flatbuffer(
114119
)
115120

116121
ir_context = backend.export_utils.create_ir_context()
117-
118122
cross_program_inline_consts_ctx = inline_consts_lib.InlineConstsContext(
119123
enable_lazy_constants=lightweight_conversion,
120124
)
121125

122126
lowered_programs = []
123127
for exported_program, sig in zip(exported_programs, signatures):
124128
# Convert ExportedProgram to Mlir Module.
125-
lowered = backend.export.exported_program_to_mlir(
126-
exported_program,
127-
ir_context=ir_context,
128-
lowering_context_plugins=[cross_program_inline_consts_ctx],
129-
)
129+
with progress.task(f"Lower to MLIR: {sig.name}"):
130+
lowered = backend.export.exported_program_to_mlir(
131+
exported_program,
132+
ir_context=ir_context,
133+
lowering_context_plugins=[cross_program_inline_consts_ctx],
134+
)
130135

131136
# Set signature.
132137
sig_name = sig.name
@@ -141,9 +146,10 @@ def exported_programs_to_flatbuffer(
141146
lowered_programs.append(lowered)
142147

143148
# Merge all lowered modules into one module.
144-
merged_module = converter_api_ext.merge_modules(
145-
[lowered.module for lowered in lowered_programs]
146-
)
149+
with progress.task("Merge MLIR Modules"):
150+
merged_module = converter_api_ext.merge_modules(
151+
[lowered.module for lowered in lowered_programs]
152+
)
147153

148154
# Prepare ai-edge-quantizer recipe.
149155
translated_recipe = None
@@ -170,7 +176,7 @@ def exported_programs_to_flatbuffer(
170176
config.canonicalizing_inf_as_min_max_float = False
171177

172178
# Run LiteRT converter passes.
173-
with ir_context:
179+
with ir_context, progress.task("Run LiteRT Converter Passes"):
174180
pass_manager = passmanager.PassManager()
175181
converter_api_ext.run_convert_to_tfl_passes(
176182
merged_module, pass_manager, config
@@ -181,9 +187,10 @@ def exported_programs_to_flatbuffer(
181187

182188
# Quantize the model if needed.
183189
if translated_recipe:
184-
model_bytes = translate_recipe.quantize_model(
185-
exporter.to_bytes(), translated_recipe
186-
)
187-
exporter = LazyModelExporter(content=model_bytes)
190+
with progress.task("Quantize Model"):
191+
model_bytes = translate_recipe.quantize_model(
192+
exporter.to_bytes(), translated_recipe
193+
)
194+
exporter = LazyModelExporter(content=model_bytes)
188195

189196
return exporter

litert_torch/backend/export.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import Any, Callable, Optional
2222

2323
from litert_torch import fx_infra
24+
from litert_torch import progress
2425
from litert_torch.backend import _torch_future
2526
from litert_torch.backend import debuginfo
2627
from litert_torch.backend import export_utils
@@ -375,7 +376,7 @@ def exported_program_to_mlir(
375376
if not ir_context:
376377
ir_context = export_utils.create_ir_context()
377378

378-
with ir_context, ir.Location.unknown():
379+
with ir_context, ir.Location.unknown(), progress.task("Create MLIR Module"):
379380

380381
module = ir.Module.create()
381382
lctx = LoweringContext(ir_context, module)

litert_torch/fx_infra/_safe_run_decompositions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
16+
1617
import operator
1718
from typing import Any, Callable
19+
from litert_torch import progress
1820
import torch
1921

2022

@@ -69,6 +71,7 @@ def annotate_force_decomp(decomp: Callable[..., Any]):
6971
return decomp
7072

7173

74+
@progress.task("ExportedProgram Run Decompositions")
7275
def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
7376
"""Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
7477

litert_torch/progress.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2024 The LiteRT Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Progress tracking and logging library."""
17+
18+
import contextlib
19+
import dataclasses
20+
import time
21+
from typing import Generator
22+
from litert_torch import _config
23+
from rich import console
24+
25+
# Public API of this module: the task() context manager and the shared console.
__all__ = ["task", "console"]

# Alias to the package-level config singleton (controls show_progress).
config = _config.config
# NOTE: rebinds the imported `rich.console` module name to a Console instance;
# all progress output goes through this shared console.
console = console.Console(color_system="auto")
29+
30+
31+
@dataclasses.dataclass
class Task:
  """One tracked unit of work on the progress stack."""

  # Human-readable task label shown in progress output.
  name: str
  # time.perf_counter() timestamp captured when the task started.
  start_time: float = 0.0
35+
36+
37+
# Module-level stack of tasks currently in progress (outermost first).
# Plain list with no locking — assumes single-threaded use; TODO confirm.
_task_stack: list[Task] = []
38+
39+
40+
def _fmt_elapsed_time(elapsed_time: float) -> str:
41+
minutes, seconds = divmod(elapsed_time, 60)
42+
return f"{int(minutes):02d}:{int(seconds):02d}"
43+
44+
45+
def _style_elapsed_time(elapsed: float) -> str:
  """Wrap the formatted duration in rich markup scaled by how long it took.

  Under 5s is dim, under 1 minute yellow, under 10 minutes bold yellow,
  anything longer bold red.
  """
  formatted = _fmt_elapsed_time(elapsed)
  thresholds = (
      (5, "dim default"),
      (60, "yellow"),
      (60 * 10, "bold yellow"),
  )
  for limit, style in thresholds:
    if elapsed < limit:
      return f"[{style}]{formatted}[/{style}]"
  return f"[bold red]{formatted}[/bold red]"
55+
56+
57+
def _task_stack_repr() -> str:
  """Render the active task stack as a 'outer > inner' breadcrumb path."""
  # Loop variable renamed so it does not shadow the module-level task().
  names = (f"[bold default]{t.name}[/bold default]" for t in _task_stack)
  return " > ".join(names)
61+
62+
63+
def _stack_elapsed_time() -> str:
  """Dim-styled total time elapsed since the outermost task began."""
  if _task_stack:
    elapsed = time.perf_counter() - _task_stack[0].start_time
  else:
    elapsed = 0
  return f"[dim default]({_fmt_elapsed_time(elapsed)})[/dim default]"
68+
69+
70+
@contextlib.contextmanager
def task(name: str) -> Generator[None, None, None]:
  """Tracks one named task, printing START / DONE / FAIL lines with timings.

  Works as a `with` context manager and — via contextmanager's
  ContextDecorator support — as a function decorator. Nested tasks are
  rendered as a breadcrumb path. All output is suppressed when
  config.show_progress is False.
  """
  if not config.show_progress:
    # Progress disabled: act as a transparent no-op context manager.
    yield
    return

  current = Task(name, time.perf_counter())
  _task_stack.append(current)

  # Snapshot the breadcrumb once; reused verbatim on DONE/FAIL lines.
  stack_view = _task_stack_repr()
  console.print(
      f"{_stack_elapsed_time()} [bold cyan][START][/bold cyan] {stack_view}"
  )

  try:
    yield
  except Exception:
    # Report the failure, then let the exception propagate unchanged.
    console.print(
        f"{_stack_elapsed_time()} [bold red][ FAIL][/bold red] {stack_view}"
    )
    raise
  else:
    elapsed = time.perf_counter() - current.start_time
    console.print(
        f"{_stack_elapsed_time()} [bold green][ DONE][/bold green]"
        f" [dim]{stack_view}[/dim]"
        f" [dim](+[/dim]{_style_elapsed_time(elapsed)}[dim])[/dim]"
    )
  finally:
    # Guarded pop keeps the stack consistent even if it was emptied elsewhere.
    if _task_stack:
      _task_stack.pop()

0 commit comments

Comments
 (0)